K-means
優缺點
優點:容易實現
缺點:可能收斂到局部最小值,在大規模數據集上收斂較慢
步驟
- 選擇K個點作爲初始聚類中心
- 計算其餘所有點到聚類中心的距離,並把每個點劃分到離它最近的聚類中心所在的聚類中去。計算距離常用歐幾里得距離公式,也叫歐氏距離。查看距離的計算方法
- 重新計算每個聚類中所有點的平均值,並將其作爲新的聚類中心點。
- 重複2、3步,直到聚類中心不再發生改變,或者算法達到預定的迭代次數,又或聚類中心的改變小於預先設定的閾值
使用後處理來提高聚類性能
局部最小值指結果還可以但並非最好結果,全局最小值是可能的最好結果
一種用於度量聚類效果的指標是SSE(誤差平方和)。SSE值越小表示數據點越接近它們的質心。
後處理方法:將具有最大SSE值的簇劃分成爲兩個簇;合併最近的質心,或者合併兩個使得SSE增幅最小的質心。
二分K-Means算法
- 將所有點作爲一個簇,然後將該簇一分爲二;
- 之後選擇其中一個簇繼續進行劃分,選擇哪一個簇進行劃分取決於對其劃分是否可以最大程度降低SSE的值。
- 不斷重複SSE的劃分過程,直到得到用戶指定的簇數目爲止。
Spark實現KMeans
關鍵步驟
聚類個數K的選擇
Spark MLlib 在 KMeansModel 類裏提供了 computeCost 方法,該方法通過計算所有數據點到其最近的中心點的平方和來評估聚類的效果。一般來說,同樣的迭代次數和算法跑的次數,這個值越小代表聚類的效果越好。但是在實際情況下,我們還要考慮到聚類結果的可解釋性,不能一味的選擇使 computeCost 結果值最小的那個 K。
初始聚類中心點的選擇
Spark MLlib K-means算法的實現在初始聚類點的選擇上,借鑑了K-means||的類K-means++實現。K-means++算法在初始點選擇上遵循一個基本原則:初始聚類中心點相互之間的距離應該儘可能的遠。基本步驟如下:
1. 從數據集中隨機選擇一個點作爲第一個初始點
2. 計算數據集中所有點與最新選擇的中心點的距離D(x)
3. 選擇下一個中心點,使得
最大
4. 重複2、3步,直到K個初始點選擇完成
MLlib的K-means實現
- 讀取訓練數據,調用KMeans.train方法對數據集進行聚類訓練,方法返回KMeansModel實例
- 使用KMeansModel.predict方法對新的數據點進行所屬聚類的預測
- 均方差(MSE),就是對各個實際存在評分的項,pow(預測評分-實際評分,2)的值進行累加,再除以項數。而均方根差(RMSE)就是MSE開根號。MSE/RMSE值越小說明預測結果越準確
參數
參數 | 含義 |
---|---|
k | 所需簇的數量。請注意,可以返回少於k個集羣,例如,如果有少於k個不同的集羣點 |
maxIterations | 運行的最大迭代次數。 |
initializationMode | 通過k-means ||指定隨機初始化或初始化。 |
initializationSteps | 確定k-means ||中的步數算法。 |
epsilon | 確定我們認爲k均值收斂的距離閾值。 |
initialModel | 用於初始化的可選集羣中心集。如果提供此參數,則僅執行一次運行。 |
Spark_K-Means_Python
from __future__ import print_function
from numpy import array
from math import sqrt
from pyspark import SparkContext
from pyspark.mllib.clustering import KMeans, KMeansModel
if __name__ == "__main__":
sc = SparkContext(appName="KmeansExample")
# Load and parse the data
data = sc.textFile("kmeans_data.txt")
parsedData = data.map(lambda line:array([float(x) for x in line.split(' ')]))
# Build the Model(cluster the data)
clusters = KMeans.train(parsedData, 2, maxIterations=10, initializationMode="random")
print(clusters.clusterCenters)
print(clusters.predict([0.2, 0.2, 0.2]))
# Evaluate clustering by computing Within Set Sum of Squared Errors
def error(point):
center = clusters.centers[clusters.predict(point)]
return sqrt(sum([x**2 for x in (point - center)]))
WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y)
print("Within Set Sum of Squared Error = " + str(WSSSE))