k-means算法實現python

import numpy as np
import matplotlib.pyplot as plt

# 兩點距離
def distance(e1, e2):
    return np.sqrt((e1[0]-e2[0])**2+(e1[1]-e2[1])**2)

# 集合中心
def means(arr):
    return np.array([np.mean([e[0] for e in arr]), np.mean([e[1] for e in arr])])

# arr中距離a最遠的元素,用於初始化聚類中心
def farthest(k_arr, arr):
    f = [0, 0]
    max_d = 0
    for e in arr:
        d = 0
        for i in range(k_arr.__len__()):
            d = d + np.sqrt(distance(k_arr[i], e))
        if d > max_d:
            max_d = d
            f = e
    return f

# arr中距離a最近的元素,用於聚類
def closest(a, arr):
    c = arr[1]
    min_d = distance(a, arr[1])
    arr = arr[1:]
    for e in arr:
        d = distance(a, e)
        if d < min_d:
            min_d = d
            c = e
    return c


if __name__=="__main__":
    ## 生成二維隨機座標(如果有數據集就更好)
    arr = np.random.randint(100, size=(100, 1, 2))[:, 0, :]  # 取所有數據的第0列到最後一列數據,100是上限,默認下限爲0
    # print(arr)

    ## 初始化聚類中心和聚類容器
    m = 5
    r = np.random.randint(arr.__len__() - 1)
    k_arr = np.array([arr[r]])    # 取到隨機一個數據
    print(k_arr)
    cla_arr = [[]]
    for i in range(m-1):
        k = farthest(k_arr, arr)
        # print("k:",k)
        k_arr = np.concatenate([k_arr, np.array([k])])  # 豎着拼接這兩個元素
        # print("k_arr",k_arr)
        cla_arr.append([])
        # print("cla:",cla_arr)

    ## 迭代聚類
    n = 20
    cla_temp = cla_arr
    for i in range(n):    # 迭代n次
        for e in arr:    # 把集合裏每一個元素聚到最近的類
            ki = 0        # 假定距離第一個中心最近
            min_d = distance(e, k_arr[ki])
            for j in range(1, k_arr.__len__()):
                if distance(e, k_arr[j]) < min_d:    # 找到更近的聚類中心
                    min_d = distance(e, k_arr[j])
                    ki = j
            cla_temp[ki].append(e)
        # 迭代更新聚類中心
        for k in range(k_arr.__len__()):
            if n - 1 == i:
                break
            k_arr[k] = means(cla_temp[k])
            cla_temp[k] = []
        print(cla_temp)

    ## 可視化展示
    col = ['HotPink', 'Aqua', 'Chartreuse', 'yellow', 'LightSalmon']
    for i in range(m):
        plt.scatter(k_arr[i][0], k_arr[i][1], linewidth=10, color=col[i])
        plt.scatter([e[0] for e in cla_temp[i]], [e[1] for e in cla_temp[i]], color=col[i])
    plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章