聚類
聚類是一種無監督的學習,它將相似的對象歸到同一個簇中。聚類的方法幾乎可以應用於所有對象,簇內的對象越相似,聚類的效果越好。K-means(K-均值聚類)算法使一種聚類算法。之所以稱之爲K-均值使因爲它可以發現k個不同的簇,且每個簇的中心採用簇中所含值的均值計算而成。
K-均值聚類算法
K-均值是發現給定數據集的k個簇的算法。簇個數k是用戶給定的,每一個簇通過其聚類中心,即簇中所有點的中心來描述。
K-均值算法的工作流程爲——首先,隨機確定k個初始點作爲質心。然後將數據集中的每個點分配到一個簇中,具體來講,爲每個點找距其最近的聚類中心,將其分配給該聚類中心所對應的簇。這一步完成之後,每個簇的聚類中心更新爲該簇所有點的平均值。
上述過程的僞代碼表示如下:
創建k個點作爲初始聚類中心(一般從樣本中隨機選擇,我的代碼裏是硬編碼爲一個數組)
當任意一個點的簇分配結果發生改變時
對數據集中的每個數據點
對每個聚類中心
計算聚類中心與數據點之間的距離
將數據點分配到距其最近的簇
對每一個簇,計算簇中所有點的均值並將均值作爲聚類中心
K-means的python實現:
# -*- coding: utf-8 -*-
from scipy.io import loadmat
from numpy import *
import matplotlib.pyplot as plt
def find_closest_centroids(centroids,data):
'''
找出每組數據最接近的聚類中心
Args:
centroids:聚類中心
data:數據集
Returns:
idx:數據集所屬聚類中心的下標
'''
data_size = len(data)
idx = zeros((data_size,1))
K = len(centroids)
for i in range(data_size):
temp = sum(power(data[i] - centroids,2),1) # 求每組數據距離的平方,求距離省略了開方過程,大小關係不受影響
idx[i] = temp.argmin() # 數據集所屬聚類中心的下標
return idx
def compute_centroids(data,idx,k):
'''
重新計算聚類中心
Args:
data:數據集
idx:數據集所屬聚類中心的下標
k:聚類數量
Returns:
centroids:更新後的聚類中心
'''
data_size = len(data)
cluster_sum = zeros((k,2)) # 一個簇中所有數據之和,用於求平均值
cluster_data_size = zeros((k,1)) # 一個簇包含的數據長度,用於求平均值
for i in range(data_size): # 遍歷每一組數據,計算cluster_sum和cluster_data_size
cluster_sum[int(idx[i])] = cluster_sum[int(idx[i])] + data[i]
cluster_data_size[int(idx[i])] = cluster_data_size[int(idx[i])] + 1
centroids = cluster_sum / cluster_data_size
return centroids
def figure_cluster(data,centroids,idx):
'''
在座標軸中根據聚類畫出數據樣本,描出聚類中心
Args:
data:數據集
centroids:聚類中心
idx:數據集所屬聚類中心的下標
Returns:
centroids:更新後的聚類中心
'''
data_size = len(data)
colors = ['r','b','y']
for i in range(data_size):
index = int(idx[i])
plt.scatter(data[i,0],data[i,1],marker='o',c=colors[index],s=15)
print(centroids[:,0],centroids[:,1])
plt.scatter(centroids[:,0],centroids[:,1],marker='x',color='000')
data_mat = loadmat("ex7data2.mat") # mat數據集
data = data_mat['X']
K = 3
initial_centroids = array([[3,3],[6,2],[8,5]]) # 硬編碼初始聚類中心
centroids = initial_centroids
'''
迭代,當更新的聚類中心不再變化時結束
'''
while True:
old_centroids = centroids
idx = find_closest_centroids(centroids,data)
centroids = compute_centroids(data,idx,K)
if (old_centroids == centroids).all():
break
figure_cluster(data,centroids,idx)
PS:此處所用的數據集是吳恩達機器學習課程中的數據集,是mat類型的數據,可自行更改爲自己的數據集