Python:獲取K-Means中心點最近的樣本

import numpy as np
import pandas as pd
from sklearn import datasets as DS
import matplotlib.pyplot as plt


def euclideanDist(A, B):
    """Return the Euclidean distance between two equally-shaped NumPy arrays.

    For 1-D vectors the result is the usual scalar distance. The squared
    differences are reduced over axis 0 only, so for 2-D inputs the result
    is one distance per column.
    """
    # np.sum reduces in native code; the builtin sum would walk the array
    # element by element in Python.
    return np.sqrt(np.sum((A - B) ** 2, axis=0))
def RandomCenters(dataSet, k):
    """Pick ``k`` distinct rows of ``dataSet`` at random as initial centers.

    Draws from NumPy's global random state without replacement, so
    ``np.random.choice`` raises ``ValueError`` when ``k`` exceeds the number
    of rows. The rows are returned via fancy indexing, i.e. as a new array
    with the same dtype as ``dataSet``.
    """
    row_count = dataSet.shape[0]
    return dataSet[np.random.choice(range(row_count), size=k, replace=False)]
def KMeans(dataSet, k, init=None, max_iter=None):
    """Cluster the rows of ``dataSet`` into ``k`` groups with Lloyd's k-means.

    Every sample is assigned to its nearest center by Euclidean distance
    (the metric of ``euclideanDist``, here computed for all pairs at once).
    Each center is then moved to the mean of its members. This repeats until
    no assignment changes.

    Parameters
    ----------
    dataSet : array-like of shape (n, m)
        One sample per row.
    k : int
        Number of clusters.
    init : array-like of shape (k, m), optional
        Initial centers. When omitted, ``k`` distinct rows are drawn with
        ``RandomCenters``. The caller's array is never modified.
    max_iter : int, optional
        Maximum number of center updates. ``None`` (default) iterates until
        the assignments stop changing.

    Returns
    -------
    Centers : ndarray of shape (k, m), float
        Final cluster centers.
    DistMatrix : ndarray of shape (n, 2)
        Column 0 holds each sample's cluster index (stored as float),
        column 1 its distance to that cluster's center.

    Raises
    ------
    ValueError
        If ``init`` does not contain exactly ``k`` centers.
    """
    dataSet = np.asarray(dataSet)
    n = dataSet.shape[0]
    initial = RandomCenters(dataSet, k) if init is None else init
    # Always work on a float copy. With integer data, writing cluster means
    # into an integer array would silently truncate them, and a caller's
    # ``init`` array must not be mutated.
    Centers = np.array(initial, dtype=float)
    if Centers.shape[0] != k:
        raise ValueError(f"expected {k} initial centers, got {Centers.shape[0]}")

    DistMatrix = np.zeros((n, 2))
    # -1 is never a valid cluster index, so the first pass always counts as a
    # change and the centers get updated at least once. With a 0 start and
    # k == 1, the random starting point would otherwise be returned as-is.
    DistMatrix[:, 0] = -1

    rows = np.arange(n)
    updates = 0
    while True:
        # (n, k) distances from every sample to every center.
        dists = np.sqrt(((dataSet[:, None, :] - Centers[None, :, :]) ** 2).sum(axis=2))
        # argmin keeps the first center on ties, like a strict "<" scan would.
        labels = dists.argmin(axis=1)
        changed = np.any(labels != DistMatrix[:, 0])
        DistMatrix[:, 0] = labels  # cluster index
        DistMatrix[:, 1] = dists[rows, labels]  # distance to that center
        if not changed:
            break
        if max_iter is not None and updates >= max_iter:
            break
        for j in range(k):
            members = dataSet[labels == j]
            # An empty cluster keeps its previous center: the mean of zero
            # rows is NaN, and a NaN center could never attract samples again.
            if len(members):
                Centers[j] = members.mean(axis=0)
        updates += 1
    return Centers, DistMatrix

def PointSelection(DistMatrix, k, n):
    """Return, for each cluster, the index of the sample closest to its center.

    Parameters
    ----------
    DistMatrix : ndarray with at least ``n`` rows and 2 columns
        Column 0 is a sample's cluster index, column 1 its distance to that
        cluster's center (the layout ``KMeans`` returns).
    k : int
        Number of clusters; clusters ``0 .. k-1`` are reported.
    n : int
        Only the first ``n`` rows of ``DistMatrix`` are examined.

    Returns
    -------
    list of int
        ``k`` row indices. Ties go to the lowest index; a cluster with no
        member at a distance below infinity gets ``-1``.
    """
    def distance_of(row):
        return DistMatrix[row, 1]

    chosen = []
    for cluster in range(k):
        # Only rows strictly closer than infinity are eligible (NaN never is).
        members = [
            row for row in range(n)
            if DistMatrix[row, 0] == cluster and distance_of(row) < np.inf
        ]
        # min() keeps the first of several equal minima, i.e. the lowest row.
        chosen.append(min(members, key=distance_of) if members else -1)
    return chosen

if __name__ == "__main__":
    import sys

    # Usage: python <script> [csv_path] [k]
    # The defaults reproduce the original experiment.
    path = sys.argv[1] if len(sys.argv) > 1 else r"D:\dataset\clusterData\bolbs_1.csv"
    k = int(sys.argv[2]) if len(sys.argv) > 2 else 2

    # Only the first two columns are clustered, so the result can be drawn in 2-D.
    X = pd.read_csv(path, header=None).to_numpy()[:, :2]
    n = len(X)
    _, DistMat = KMeans(X, k)
    Points = PointSelection(DistMat, k, n)

    # Color each sample by its cluster index.
    plt.scatter(X[:, 0], X[:, 1], c=DistMat[:, 0])
    # PointSelection reports -1 for a cluster without members; indexing with
    # it would silently mark the last sample, so drop those entries.
    CP = X[[p for p in Points if p >= 0]]
    # Stars mark the sample nearest to each cluster center.
    plt.scatter(CP[:, 0], CP[:, 1], marker="*", s=200)
    plt.show()

 

發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
相關文章