k-means算法
K-均值聚類算法(k-means clustering algorithm)是一種無監督聚類算法。本文前部分介紹算法原理及優缺點,後面通過Python代碼實現一個簡版的k-means算法。
優缺點
- 優點:簡潔快速,算法的關鍵在於初始中心的選擇和距離度量。
- 缺點:
- K值(聚類的數目)需要事先確定。
- 聚類結果對初始類中心的選取較爲敏感。
- 容易陷入局部最優。
- 只能發現球型簇
- 時間複雜度:,其中,t爲迭代次數,K爲簇的數目,m爲記錄數,n爲維數。
- 空間複雜度:,其中,K爲簇的數目,m爲記錄數,n爲維數。
算法流程
-
1.隨機給定各簇中心 $ c_1, c_2,…, c_n $
-
2.計算各樣本點到簇中心的距離(一般採用歐式距離),將樣本點歸類到距離最小的簇中,公式如下:
其中,是使目標函數取最小值時的變量值
-
3.更新各簇中心$ c_1, c_2,…, c_n $
上式中,爲簇y的樣本總數
-
4.重複步驟2、3,直到達到收斂精度即可
停止條件
- 達到預先設置的最大迭代次數。
- 聚類中心不再發生變化。
- 相鄰兩次聚類結果中差值變化小於某一個閾值。
K值的確定
-
經驗法: 在實際工作中,結合業務的場景和需求,來決定幾類確定K值
-
肘部法則: 肘部法則通過成本函數來刻畫的,其是通過將不同K值的成本函數刻畫出來,隨着K值的增大,平均畸變程度的改善效果會不斷降低。因此。在找出K值增大的過程中,畸變程度下降幅度最大的位置所對應的K較爲合理。(注:成本函數爲各類的畸變程度之和與其內部成員位置距離的平方和,最優解是成本函數最小化爲目標。公式表示爲
其中,是第k個質心的位置。 -
規則法:
K-means算法Python實現
下面代碼將k-means封裝成類,實例化對象後調用fit()和transform()方法實現對數據的訓練和轉化,talk is checp, show you the code!
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
class Kmeans():
def __init__(self, n_clusters=8, max_iter=10, tol=1e-4):
self.__n_cluster = n_clusters
self.__max_iter = max_iter
self.__tol = tol
self.__cluster_centers_ = None
self.__n_iter_ = 0
self.__label = None
self.info = ''
@property
def cluster_centers_(self):
return self.__cluster_centers_
@property
def n_iter(self):
return self.__n_iter_
@property
def label(self):
return self.__label
def __generate_cluster_centers(self, data):
"""隨機生成類中心"""
m, n = data.shape
col_max_value = data.max(axis=0)
col_min_value = data.min(axis=0)
cluster_centers = np.empty(shape=[self.__n_cluster, n])
for col_index in range(n):
col_value = np.random.uniform(col_min_value[col_index],
col_max_value[col_index],
self.__n_cluster)
cluster_centers[:, col_index] = col_value
return cluster_centers
def __calu_distance(self, sample, cluster_center):
"""
計算一個樣本sample到一個聚類中心的距離(歐式距離)
:param sample: 輸入的一個樣本
:return: 距離
"""
return np.sqrt(np.sum(np.square(cluster_center - sample)))
def __category(self, sample):
"""
計算一個sample所屬的簇
:param sample:
:return: 簇索引
"""
distance_list = []
for cluster_center_index in range(self.__n_cluster):
cluster_center = self.__cluster_centers_[cluster_center_index, :]
distance_list.append(self.__calu_distance(sample=sample, cluster_center=cluster_center))
return distance_list.index(min(distance_list))
def __update_cluster_centers(self, data, label):
"""
計算類的中心
:param label:
:return:
"""
cluster_centers = np.empty(shape=[self.__n_cluster, data.shape[1]])
for cluster_index in range(self.__n_cluster):
sample_index_array = np.where(label == cluster_index)
samples = data[sample_index_array]
center = samples.mean(axis=0)
cluster_centers[cluster_index] = center
return cluster_centers
def __is_terminate(self, cluster_centers, label):
"""
是否滿足終止條件(1. 達到最大迭代次數; 2. 簇中心不再變化; 3. 兩次聚類結果差值小於某一個閾值
:param label:
:return:
"""
diff = self.__calu_diff(label)
if self.__n_iter_ >= self.__max_iter:
self.info = '達到最大迭代次數: ' + str(self.__max_iter) + ',聚類結束。'
return True
if (self.__cluster_centers_ == cluster_centers).all():
self.info = '兩次聚類簇中心未發生變化,聚類結束。'
return True
if diff <= self.__tol:
self.info = '聚類結果差值爲 {},小於 {},聚類結束。'.format(diff, self.__tol)
return True
self.__label = label
return False
def __calu_diff(self, label):
"""
計算兩次聚類結果的差值
:param label: 最近的一次聚類結果
:return: 差值佔比
"""
ans = label == self.__label
length = len(ans)
ans_dict = Counter(ans)
diff_num = ans_dict.get(False)
if diff_num == None:
return 0
else:
return diff_num / length
def fit(self, data):
"""
將數據分簇,並計算簇的中心
:param data: 數據
:return:
"""
label = np.array([np.nan] * data.shape[0])
self.__cluster_centers_ = self.__generate_cluster_centers(data=data)
while True:
# 將每個樣本劃分到最近的簇中
for sample_index in range(data.shape[0]):
category = self.__category(sample=data[sample_index])
label[sample_index] = category
# 更新簇中心
cluster_centers = self.__update_cluster_centers(data=data, label=label)
self.__n_iter_ += 1
# 檢查是否滿足終止條件
if self.__is_terminate(cluster_centers=cluster_centers, label=label):
self.__cluster_centers_ = cluster_centers
self.print_info()
break
self.__cluster_centers_ = cluster_centers
self.__label = label.copy()
def transform(self, data):
"""
計算每個數據到各聚類中心的距離
:param data:
:return:
"""
ans = np.empty([data.shape[0], self.__n_cluster])
for index in range(data.shape[0]):
tmp = []
for center in self.__cluster_centers_:
tmp.append(self.__calu_distance(sample=data[index], cluster_center=center))
center_array = np.array(tmp)
ans[index] = center_array
return ans
def predict(self, data):
"""
預測數據所屬的簇
:param data: 待預測數據
:return: 數據所屬類別
"""
distiance = self.transform(data)
cluster_labels = np.argmax(distiance, axis=1)
return cluster_labels
def print_info(self):
print('clustering is terminate!!!')
print('結束原因:{}'.format(self.info))
print('迭代次數 {}'.format(self.__n_iter_))
print("簇中心:\n", self.__cluster_centers_)
隨機生成正太分佈的數據
def generate_test_data(size=(120, 2), clusters=3, loc=10, scale=10):
cluster_center = np.array([[10, 10], [50, 100], [100, 50]])
data = np.empty(shape=[0, size[1]])
for cluster_index in range(clusters):
new_data_x = np.random.normal(loc=cluster_center[cluster_index][0],
scale=scale, size=[int(size[0] / clusters),1])
new_data_y = np.random.normal(loc=cluster_center[cluster_index][1],
scale=scale, size=[int(size[0] / clusters), 1])
new_data = np.concatenate((new_data_x, new_data_y), axis=1)
data = np.concatenate([data, new_data], axis=0)
# print(data.shape)
return data
聚類結果可視化函數
def show_data(data, label, cluster_centers=[]):
filled_markers = ('o', 'v', '^', '<', '>', '8', 's', 'p', 'h', 'H', 'D', 'd', 'P', 'X')
filled_colors = ('black', 'green', 'darkorchid', 'darkred', 'darksalmon', 'darkslategray')
plt.figure()
for sample, sample_label in zip(data, label):
plt.scatter(sample[0], sample[1], marker=filled_markers[sample_label],
color=filled_colors[sample_label], s=40, label='cluster_' + str(sample_label))
if len(cluster_centers) > 0:
for center in cluster_centers:
plt.scatter(center[0], center[1], marker='*', color='red', s=80)
# plt.legend(scatterpoints=1)
plt.show()
主函數調用
if __name__ == '__main__':
model = Kmeans(n_clusters=3, max_iter=10000)
data = generate_test_data(scale=15)
model.fit(data)
label = model.label.astype('int')
centers = model.cluster_centers_
show_data(data, label, centers)
運行上面代碼段,輸出下面結果
clustering is terminate!!!
結束原因:兩次聚類簇中心未發生變化,聚類結束。
迭代次數 5
簇中心:
[[100.46756833 48.52330939]
[ 49.11200803 101.26172269]
[ 9.27754808 9.65780098]]