手寫K-means及K-means++

手寫K-means及K-means++經典算法及實戰

前段時間在學校,看了一篇關於K-means-u的聚類論文,當時對聚類只是聽過,但對許多經典算法和練習都不夠,所以今天專門記錄一下,當然也查閱了網上許多資料,如果本文哪有紕漏,歡迎各位的批評和建議

關於K-means和K-means++的算法流程,我這裏就不細講了,之前做過一個PPT,點擊下方鏈接即可查看

https://slides.com/huozhang/clusting/fullscreen

K-means

# 手寫k-means算法

# 導入必要的庫
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def distance(dec1, dec2) -> float:
    return np.sqrt(np.sum(np.square(dec1 - dec2)))  # 計算歐式距離


def K_Means(data, k):
    K = np.random.uniform(0, 10, (k, data.shape[1]))  # 初始化中心點位矩陣
    ret = np.zeros([data.shape[0], data.shape[1]])  # 構造一個答案矩陣
    flag = True  # 定義標記變量
    count = 1
    while flag:

        flag = False
        for i in range(data.shape[0]):
            minIndex = -1  # 定義得到最短距離的時候的臨時中心點位
            minDis = np.inf  # 定義最短距離
            for j in range(K.shape[0]):
                dis = distance(data[i], K[j])  # 計算距離
                if dis < minDis:
                    minDis = dis
                    minIndex = j

            # 計算完成後 將所屬距離點位填入答案矩陣
            ret[i][0] = minDis
            ret[i][1] = minIndex

        # 計算好位置了,現在開始重新計算均值,查詢中心點
        for i in range(k):
            cluster = data[ret[:, 1] == i]
            # print("中心點位——————", cluster)
            if len(cluster) == 0:
                pass
            else:
                center = np.mean(cluster, axis=0)
                # print("center 爲", center)
                # print("K[i] 爲", K[i])
                if (center == K[i]).all():  # 這裏必須用all()因爲 ndarry()的性質 具體大家可以百度
                    pass
                else:
                    K[i] = center
                    flag = True

        _X = K[:, 0]
        _Y = K[:, 1]
        plt.scatter(_x, _y)
        plt.scatter(_X, _Y, marker='X', color='r')
        plt.title("%d  of convergence" % count)
        count += 1
        plt.show()


if __name__ == '__main__':
    # 構造三簇數據
    data1 = np.random.uniform(0, 2, (10, 2))
    data2 = np.random.uniform(3, 6, (10, 2))
    data3 = np.random.uniform(8, 10, (10, 2))
    data = np.r_[data1, data2, data3]  # 按列上下合併數據  np.c_[]是按行 左右合併

    _x = data[:, 0]
    _y = data[:, 1]
    plt.rcParams['font.sans-serif'] = 'SimHei'
    # plt.scatter(_x, _y, marker="o")
    # plt.show()
    K_Means(data, 3)  # 自定義定義3箇中心

代碼進行了可視化展示,可以看看每次收斂效果


K-means++

主要解決K-means初始點位選擇問題,k-means++即基本解決了這一問題,點位選擇完成過後,進行相應的收斂和迭代,以減少迭代次數和SSE(最大平方誤差和)

# 手寫k-means++算法

# 導入必要的庫
import numpy as np
import matplotlib.pyplot as plt
import random
import math


def euler_distance(point1, point2) -> float:
    """
    計算兩點之間的歐拉距離,支持多維
    """
    distance = 0.0
    for a, b in zip(point1, point2):  # 將x,y對應成元組進行計算
        distance += math.pow(a - b, 2)
    return math.sqrt(distance)


def distance(point, cluster) -> float:
    min_dist = math.inf
    for i, centroid in enumerate(cluster):
        dist = euler_distance(centroid, point)  # 計算每一個點到每個簇的距離
        if dist < min_dist:
            min_dist = dist
    return min_dist  # 返回離簇最近的距離


def KMeansplus(data, k) -> list:
    """
    K-means++ 主要解決K-means的初始點位選擇問題,返回點位後,再進行收斂,此函數僅完成簇點選取
    :param data: 數據集(測試集)
    :param k: 簇個數
    :return: 返回簇x,y列表
    """
    cluster_list = []
    cluster_list.append(random.choice(data).tolist())
    ret = [0 for _ in range(len(data))]  # 構造距離空列表
    # print(ret)
    for _ in range(1, k):  # 這裏從第二個點開始,因爲第一個點是隨機的
        sum_dis = 0
        for i, point in enumerate(data):
            ret[i] = distance(point, cluster_list)
            sum_dis += ret[i]  # 累加距離
        sum_dis *= random.random()  # 利用輪盤法*[0~1]裏面的隨機數
        for i, point_dis in enumerate(ret):
            sum_dis -= point_dis  # 依次減距離
            if sum_dis <= 0:
                cluster_list.append(data[i].tolist())  # 直到sum_dis爲0時 將此點作爲第二個點
                break

    return cluster_list  


if __name__ == '__main__':
    # 構造三簇數據
    data1 = np.random.uniform(0, 2, (20, 2))
    data2 = np.random.uniform(3, 6, (20, 2))
    data3 = np.random.uniform(8, 10, (20, 2))
    data = np.r_[data1, data2, data3]  # 按列上下合併數據  np.c_[]是按行 左右合併

    np.random.shuffle(data)

    _x = data[:, 0]
    _y = data[:, 1]
    plt.rcParams['font.sans-serif'] = 'SimHei'
    plt.scatter(_x, _y, marker="o")

    center = KMeansplus(data, 3)  # 自定義定義3箇中心
    # print("k-means++ 的中心爲\n", center)
    center = np.array(center)
    center_x = center[:, 0]
    center_y = center[:, 1]
    plt.scatter(center_x, center_y, marker='X', c='r')
    plt.show()
    

可以看見在點位選擇上基本和以完成收斂的點位非常接近,之後再通過幾次收斂達到最佳

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章