K-means算法的原理與實現

K-Means算法的思想很簡單,對於給定的樣本集,按照樣本之間的距離大小,將樣本集劃分爲K個簇。讓簇內的點儘量緊密的連在一起,而讓簇間的距離儘量的大。

    如果用數據表達式表示,假設簇劃分爲(C1,C2,...Ck)(C1,C2,...Ck),則我們的目標是最小化平方誤差E:

E=∑i=1k∑x∈Ci||x−μi||22E=∑i=1k∑x∈Ci||x−μi||22

    其中μiμi是簇CiCi的均值向量,有時也稱爲質心,表達式爲:

μi=1|Ci|∑x∈Cixμi=1|Ci|∑x∈Cix

    如果我們想直接求上式的最小值並不容易,這是一個NP難的問題,因此只能採用啓發式的迭代方法。

    K-Means採用的啓發式方式很簡單,用下面一組圖就可以形象的描述。

    上圖a表達了初始的數據集,假設k=2。在圖b中,我們隨機選擇了兩個k類所對應的類別質心,即圖中的紅色質心和藍色質心,然後分別求樣本中所有點到這兩個質心的距離,並標記每個樣本的類別爲和該樣本距離最小的質心的類別,如圖c所示,經過計算樣本和紅色質心和藍色質心的距離,我們得到了所有樣本點的第一輪迭代後的類別。此時我們對我們當前標記爲紅色和藍色的點分別求其新的質心,如圖4所示,新的紅色質心和藍色質心的位置已經發生了變動。圖e和圖f重複了我們在圖c和圖d的過程,即將所有點的類別標記爲距離最近的質心的類別並求新的質心。最終我們得到的兩個類別如圖f。

    當然在實際K-Mean算法中,我們一般會多次運行圖c和圖d,才能達到最終的比較優的類別。

附一張表格,簡單實現一下給大家

代碼如下

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

data = pd.read_csv(r'd:\Desktop\shuju\company.csv', engine='python', encoding='gbk')
print(data.columns)

# 設k=3 分爲類
# 不考慮年齡
# 各類中心點的座標爲
# category1 = [100, 10]
# category2 = [200, 20]
# category3 = [300, 30]

x = data['平均每次消費金額']
y = data['平均消費週期(天)']


def KMeans(x11, y11, x22, y22, x33, y33, ):
    data['distance1'] = np.sqrt((x11 - x) ** 2 + (y11 - y) ** 2)
    data['distance2'] = np.sqrt((x22 - x) ** 2 + (y22 - y) ** 2)
    data['distance3'] = np.sqrt((x33 - x) ** 2 + (y33 - y) ** 2)
    # print(data)

    # 將d1 d2 d3 取出來形成test表
    test = data[['distance1', 'distance2', 'distance3']]
    # print(test)

    # 找到每列最小值的所在的columns
    min_zxz = test.idxmin(axis=1)
    # 將返回的最小值的列索引的值加到表中。
    data['addr_min'] = min_zxz.values

    # print(data)
    # 根據分組求出各個columns的均值 重新賦值   x,y,z

    df = data[['平均每次消費金額', '平均消費週期(天)', 'addr_min']].groupby(by='addr_min').mean()

    x1 = df['平均每次消費金額']['distance1']
    x2 = df['平均每次消費金額']['distance2']
    x3 = df['平均每次消費金額']['distance3']
    y1 = df['平均消費週期(天)']['distance1']
    y2 = df['平均消費週期(天)']['distance2']
    y3 = df['平均消費週期(天)']['distance3']

    # 遞歸出口
    if x11 == x1 and y11 == y1 and x22 == x2 and y22 == y2 and x33 == x3 and y33 == y3:
        print(data)
        return x1, y1, x2, y2, x3, y3
    # print('++++++++++++++++++++++++++++++++++')
    return KMeans(x1, y1, x2, y2, x3, y3, )


jg = KMeans(100, 10, 200, 20, 300, 30)

print(jg)

mask1 = data['addr_min'] == 'distance1'
mask2 = data['addr_min'] == 'distance2'
mask3 = data['addr_min'] == 'distance3'
index1 = data.index[mask1]
data1 = data[['平均每次消費金額', '平均消費週期(天)']][mask1]
data2 = data[['平均每次消費金額', '平均消費週期(天)']][mask2]
data3 = data[['平均每次消費金額', '平均消費週期(天)']][mask3]
plt.figure()
plt.scatter(data1['平均消費週期(天)'], data1['平均每次消費金額'])
plt.scatter(data2['平均消費週期(天)'], data2['平均每次消費金額'])
plt.scatter(data3['平均消費週期(天)'], data3['平均每次消費金額'])
plt.show()

 

 

K-Means是個簡單實用的聚類算法,這裏對K-Means的優缺點做一個總結。

    K-Means的主要優點有:

    1)原理比較簡單,實現也是很容易,收斂速度快。

    2)聚類效果較優。

    3)算法的可解釋度比較強。

    4)主要需要調參的參數僅僅是簇數k。

    K-Means的主要缺點有:

    1)K值的選取不好把握

    2)對於不是凸的數據集比較難收斂

    3)如果各隱含類別的數據不平衡,比如各隱含類別的數據量嚴重失衡,或者各隱含類別的方差不同,則聚類效果不佳。

    4) 採用迭代方法,得到的結果只是局部最優。

    5) 對噪音和異常點比較的敏感。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章