【數據挖掘】數據挖掘經典算法之The K-means algorithm

聚類

  聚類是一種無監督的學習,它將相似的對象歸到同一個簇中。聚類的方法幾乎可以應用於所有對象,簇內的對象越相似,聚類的效果越好。K-means(K-均值聚類)算法使一種聚類算法。之所以稱之爲K-均值使因爲它可以發現k個不同的簇,且每個簇的中心採用簇中所含值的均值計算而成。

K-均值聚類算法

  K-均值是發現給定數據集的k個簇的算法。簇個數k是用戶給定的,每一個簇通過其聚類中心,即簇中所有點的中心來描述。
  K-均值算法的工作流程爲——首先,隨機確定k個初始點作爲質心。然後將數據集中的每個點分配到一個簇中,具體來講,爲每個點找距其最近的聚類中心,將其分配給該聚類中心所對應的簇。這一步完成之後,每個簇的聚類中心更新爲該簇所有點的平均值。
  上述過程的僞代碼表示如下:

	創建k個點作爲初始聚類中心(一般從樣本中隨機選擇,我的代碼裏是硬編碼爲一個數組)
	當任意一個點的簇分配結果發生改變時
		對數據集中的每個數據點
			對每個聚類中心
				計算聚類中心與數據點之間的距離
			將數據點分配到距其最近的簇
		對每一個簇,計算簇中所有點的均值並將均值作爲聚類中心

K-means的python實現:

# -*- coding: utf-8 -*-

from scipy.io import loadmat
from numpy import *
import matplotlib.pyplot as plt


def find_closest_centroids(centroids,data): 
    '''
    找出每組數據最接近的聚類中心
    Args:
        centroids:聚類中心
        data:數據集
    Returns:
        idx:數據集所屬聚類中心的下標
    '''
    data_size = len(data)   
    idx = zeros((data_size,1))
    K = len(centroids)  
    
    for i in range(data_size):
        temp = sum(power(data[i] - centroids,2),1) # 求每組數據距離的平方,求距離省略了開方過程,大小關係不受影響
        idx[i] = temp.argmin() # 數據集所屬聚類中心的下標

    return idx

def compute_centroids(data,idx,k):
    '''
    重新計算聚類中心
    Args:
        data:數據集
        idx:數據集所屬聚類中心的下標
        k:聚類數量
    Returns:
        centroids:更新後的聚類中心
    '''
    data_size = len(data)
    cluster_sum = zeros((k,2)) # 一個簇中所有數據之和,用於求平均值
    cluster_data_size = zeros((k,1)) # 一個簇包含的數據長度,用於求平均值
    
    for i in range(data_size): # 遍歷每一組數據,計算cluster_sum和cluster_data_size
        cluster_sum[int(idx[i])] = cluster_sum[int(idx[i])] + data[i] 
        cluster_data_size[int(idx[i])] = cluster_data_size[int(idx[i])] + 1
    
    centroids = cluster_sum / cluster_data_size
    
    return centroids

def figure_cluster(data,centroids,idx):
    '''
    在座標軸中根據聚類畫出數據樣本,描出聚類中心
    
    Args:
        data:數據集
        centroids:聚類中心
        idx:數據集所屬聚類中心的下標
    Returns:
        centroids:更新後的聚類中心
    '''
    data_size = len(data)
    colors = ['r','b','y'] 
    
    for i in range(data_size):
        index = int(idx[i])
        plt.scatter(data[i,0],data[i,1],marker='o',c=colors[index],s=15)
    
    print(centroids[:,0],centroids[:,1])
    plt.scatter(centroids[:,0],centroids[:,1],marker='x',color='000')


data_mat = loadmat("ex7data2.mat") # mat數據集
data = data_mat['X']

K = 3

initial_centroids = array([[3,3],[6,2],[8,5]]) # 硬編碼初始聚類中心

centroids = initial_centroids

'''
迭代,當更新的聚類中心不再變化時結束
'''
while True:
    old_centroids = centroids
    idx = find_closest_centroids(centroids,data)
    centroids = compute_centroids(data,idx,K)

    if (old_centroids == centroids).all():
        break

figure_cluster(data,centroids,idx)

PS:此處所用的數據集是吳恩達機器學習課程中的數據集,是mat類型的數據,可自行更改爲自己的數據集

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章