K-means原理與Python實現

k-means算法

K-均值聚類算法(k-means clustering algorithm)是一種無監督聚類算法。本文前部分介紹算法原理及優缺點,後面通過Python代碼實現一個簡版的k-means算法。

優缺點

  • 優點:簡潔快速,算法的關鍵在於初始中心的選擇和距離度量。
  • 缺點
    1. K值(聚類的數目)需要事先確定。
    2. 聚類結果對初始類中心的選取較爲敏感。
    3. 容易陷入局部最優。
    4. 只能發現球型簇
  • 時間複雜度O(tKmn)O(tKmn),其中,t爲迭代次數,K爲簇的數目,m爲記錄數,n爲維數。
  • 空間複雜度O((m+k)n)O((m+k)n),其中,K爲簇的數目,m爲記錄數,n爲維數。

算法流程

  • 1.隨機給定各簇中心 $ c_1, c_2,…, c_n $

  • 2.計算各樣本點x1,x2,...,xnx_1, x_2, ..., x_n到簇中心的距離(一般採用歐式距離),將樣本點歸類到距離最小的簇中,公式如下:

    yiargminxicy2,i=1,2,...,n y_i \leftarrow argmin\|x_i - c_y\|^2, i = 1,2,...,n

    ​ 其中,argminargmin是使目標函數取最小值時的變量值

  • 3.更新各簇中心$ c_1, c_2,…, c_n $

    ci=1nyi:yiyxi    y=1,2,...,c c_i = \frac{1}{ny}\sum_{i:y_i\in y}x_i \ \ \ \ 其中,y=1,2,...,c

    ​ 上式中,nyny爲簇y的樣本總數

  • 4.重複步驟2、3,直到達到收斂精度即可

停止條件

  • 達到預先設置的最大迭代次數。
  • 聚類中心不再發生變化。
  • 相鄰兩次聚類結果中差值變化小於某一個閾值。

K值的確定

  1. 經驗法: 在實際工作中,結合業務的場景和需求,來決定幾類確定K值

  2. 肘部法則: 肘部法則通過成本函數來刻畫的,其是通過將不同K值的成本函數刻畫出來,隨着K值的增大,平均畸變程度的改善效果會不斷降低。因此。在找出K值增大的過程中,畸變程度下降幅度最大的位置所對應的K較爲合理。(注:成本函數爲各類的畸變程度之和與其內部成員位置距離的平方和,最優解是成本函數最小化爲目標。公式表示爲
    J=k=1kickixick2 J=\sum_{k=1}^k\sum_{i\in c_k}^i|x_i-c_k|^2
    其中,ckc_k是第k個質心的位置。

  3. 規則法: k=n/2k=\sqrt{n/2}

K-means算法Python實現

下面代碼將k-means封裝成類,實例化對象後調用fit()和transform()方法實現對數據的訓練和轉化,talk is checp, show you the code!

import numpy as np
from collections import Counter
import matplotlib.pyplot as plt


class Kmeans():

    def __init__(self, n_clusters=8, max_iter=10, tol=1e-4):
        self.__n_cluster = n_clusters
        self.__max_iter = max_iter
        self.__tol = tol
        self.__cluster_centers_ = None
        self.__n_iter_ = 0

        self.__label = None

        self.info = ''

    @property
    def cluster_centers_(self):
        return self.__cluster_centers_

    @property
    def n_iter(self):
        return self.__n_iter_

    @property
    def label(self):
        return self.__label

    def __generate_cluster_centers(self, data):
        """隨機生成類中心"""
        m, n = data.shape
        col_max_value = data.max(axis=0)
        col_min_value = data.min(axis=0)

        cluster_centers = np.empty(shape=[self.__n_cluster, n])
        for col_index in range(n):
            col_value = np.random.uniform(col_min_value[col_index],
                                          col_max_value[col_index],
                                          self.__n_cluster)
            cluster_centers[:, col_index] = col_value
        return cluster_centers

    def __calu_distance(self, sample, cluster_center):
        """
        計算一個樣本sample到一個聚類中心的距離(歐式距離)
        :param sample: 輸入的一個樣本
        :return: 距離
        """
        return np.sqrt(np.sum(np.square(cluster_center - sample)))

    def __category(self, sample):
        """
        計算一個sample所屬的簇
        :param sample:
        :return: 簇索引
        """
        distance_list = []
        for cluster_center_index in range(self.__n_cluster):
            cluster_center = self.__cluster_centers_[cluster_center_index, :]
            distance_list.append(self.__calu_distance(sample=sample, cluster_center=cluster_center))

        return distance_list.index(min(distance_list))

    def __update_cluster_centers(self, data, label):
        """
        計算類的中心
        :param label:
        :return:
        """
        cluster_centers = np.empty(shape=[self.__n_cluster, data.shape[1]])
        for cluster_index in range(self.__n_cluster):
            sample_index_array = np.where(label == cluster_index)
            samples = data[sample_index_array]
            center = samples.mean(axis=0)
            cluster_centers[cluster_index] = center
        return cluster_centers

    def __is_terminate(self, cluster_centers, label):
        """
        是否滿足終止條件(1. 達到最大迭代次數; 2. 簇中心不再變化; 3. 兩次聚類結果差值小於某一個閾值
        :param label:
        :return:
        """
        diff = self.__calu_diff(label)
        if self.__n_iter_ >= self.__max_iter:
            self.info = '達到最大迭代次數: ' + str(self.__max_iter) + ',聚類結束。'
            return True

        if (self.__cluster_centers_ == cluster_centers).all():
            self.info = '兩次聚類簇中心未發生變化,聚類結束。'
            return True

        if diff <= self.__tol:
            self.info = '聚類結果差值爲 {},小於 {},聚類結束。'.format(diff, self.__tol)
            return True

        self.__label = label
        return False

    def __calu_diff(self, label):
        """
        計算兩次聚類結果的差值
        :param label:  最近的一次聚類結果
        :return: 差值佔比
        """
        ans = label == self.__label
        length = len(ans)
        ans_dict = Counter(ans)
        diff_num = ans_dict.get(False)
        if diff_num == None:
            return 0
        else:
            return diff_num / length

    def fit(self, data):
        """
        將數據分簇,並計算簇的中心
        :param data:  數據
        :return:
        """
        label = np.array([np.nan] * data.shape[0])
        self.__cluster_centers_ = self.__generate_cluster_centers(data=data)
        while True:
            # 將每個樣本劃分到最近的簇中
            for sample_index in range(data.shape[0]):
                category = self.__category(sample=data[sample_index])
                label[sample_index] = category

            # 更新簇中心
            cluster_centers = self.__update_cluster_centers(data=data, label=label)
            self.__n_iter_ += 1
            # 檢查是否滿足終止條件
            if self.__is_terminate(cluster_centers=cluster_centers, label=label):
                self.__cluster_centers_ = cluster_centers
                self.print_info()
                break
            self.__cluster_centers_ = cluster_centers
            self.__label = label.copy()

    def transform(self, data):
        """
        計算每個數據到各聚類中心的距離
        :param data:
        :return:
        """
        ans = np.empty([data.shape[0], self.__n_cluster])
        for index in range(data.shape[0]):
            tmp = []
            for center in self.__cluster_centers_:
                tmp.append(self.__calu_distance(sample=data[index], cluster_center=center))
            center_array = np.array(tmp)
            ans[index] = center_array
        return ans

    def predict(self, data):
        """
        預測數據所屬的簇
        :param data:  待預測數據
        :return: 數據所屬類別
        """
        distiance = self.transform(data)
        cluster_labels = np.argmax(distiance, axis=1)
        return cluster_labels

    def print_info(self):
        print('clustering is terminate!!!')
        print('結束原因:{}'.format(self.info))
        print('迭代次數 {}'.format(self.__n_iter_))
        print("簇中心:\n", self.__cluster_centers_)

隨機生成正太分佈的數據

def generate_test_data(size=(120, 2), clusters=3, loc=10, scale=10):
    cluster_center = np.array([[10, 10], [50, 100], [100, 50]])
    data = np.empty(shape=[0, size[1]])
    for cluster_index in range(clusters):
        new_data_x = np.random.normal(loc=cluster_center[cluster_index][0],
                                      scale=scale, size=[int(size[0] / clusters),1])
        new_data_y = np.random.normal(loc=cluster_center[cluster_index][1],
                                      scale=scale, size=[int(size[0] / clusters), 1])
        new_data = np.concatenate((new_data_x, new_data_y), axis=1)
        data = np.concatenate([data, new_data], axis=0)
    # print(data.shape)
    return data

聚類結果可視化函數

def show_data(data, label, cluster_centers=[]):
    filled_markers = ('o', 'v', '^', '<', '>', '8', 's', 'p', 'h', 'H', 'D', 'd', 'P', 'X')
    filled_colors = ('black', 'green', 'darkorchid', 'darkred', 'darksalmon', 'darkslategray')

    plt.figure()
    for sample, sample_label in zip(data, label):
        plt.scatter(sample[0], sample[1], marker=filled_markers[sample_label],
                    color=filled_colors[sample_label], s=40, label='cluster_' + str(sample_label))

    if len(cluster_centers) > 0:
        for center in cluster_centers:
            plt.scatter(center[0], center[1], marker='*', color='red', s=80)
    # plt.legend(scatterpoints=1)
    plt.show()

主函數調用

if __name__ == '__main__':
	model = Kmeans(n_clusters=3, max_iter=10000)
    data = generate_test_data(scale=15)
    model.fit(data)
    label = model.label.astype('int')
    centers = model.cluster_centers_
    show_data(data, label, centers)

運行上面代碼段,輸出下面結果

clustering is terminate!!!
結束原因:兩次聚類簇中心未發生變化,聚類結束。
迭代次數 5
簇中心:
 [[100.46756833  48.52330939]
 [ 49.11200803 101.26172269]
 [  9.27754808   9.65780098]]

可視化結果

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章