聚類是一個將數據集中在某些方面相似的數據成員進行分類組織的過程,聚類就是一種發現這種內在結構的技術,聚類技術經常被稱爲無監督學習。
k均值聚類是最著名的劃分聚類算法,由於簡潔和效率使得他成爲所有聚類算法中最廣泛使用的。給定一個數據點集合和需要的聚類數目k,k由用戶指定,k均值算法根據某個距離函數反覆把數據分入k個聚類中。(摘自百度百科)
假設對基本的二維平面上的點進行K均值聚類,其實現基本步驟是:
- 事先選定好K個聚類中心(假設要分爲K類)。
- 算出每一個點到這K個聚類中心的距離,然後把該點分配給距離它最近的一個聚類中心。
- 更新聚類中心。算出每一個類別裏面所有點的平均值,作爲新的聚類中心。
- 給定迭代此次數,不斷重複步驟2,3,達到該迭代次數後自動停止。
思想很簡單,實現起來也很簡單,附上代碼(有註釋):
import numpy as np
import matplotlib.pyplot as plt
#np.random.seed(300)
x=np.random.rand(200)*15 #產生要聚類的數據點,(0,15)之間
y=np.random.rand(200)*15
center_x=[] #存放聚類中心座標
center_y=[]
result_x=[] #存放每次迭代後每一小類的座標
result_y=[]
number_cluster=4 #簇數
time=50 #迭代次數
color=['red','blue','black','orange']
for i in range(number_cluster): # 隨機生成中心
result_x.append([]) #順便初始化存放聚類結果的列表
result_y.append([])
x1 = np.random.choice(x) #爲了避免出現聚類後有的簇一個點也沒有,
y1 = np.random.choice(y) #乾脆就以某一個數據點爲中心
if x1 not in center_x and y1 not in center_y:
center_x.append(x1)
center_y.append(y1)
plt.scatter(x,y) #畫出數據圖
plt.title('init plot')
plt.show()
def K_means():
for t in range(time):
for i in range(len(x)):
distance = [] #存放每個點到各中心的距離
for j in range(len(center_x)):
k = (center_x[j] - x[i]) ** 2 + (center_y[j] - y[i]) ** 2 #距離
distance.append([k])
result_x[distance.index(min(distance))].append(x[i]) #聚類
result_y[distance.index(min(distance))].append(y[i])
plt.title('iterations:'+str(t+1))
for i in range(number_cluster):
plt.scatter(result_x[i], result_y[i], c=color[i])
plt.show()
# 更新位置
center_x.clear()
center_y.clear()
for i in range(number_cluster):
ave_x = np.mean(result_x[i])
ave_y = np.mean(result_y[i])
center_x.append(ave_x)
center_y.append(ave_y)
if __name__=='__main__':
K_means()
結果展示:
1.初始化:
2.第一次迭代:
3.第二次迭代:
4.第九次迭代(收斂):