文章目錄
操作平臺: windows10, python37, jupyter
一、K-means算法原理
聚類的概念:一種無監督的學習,事先不知道類別,自動將相似的對象歸到同一個簇中。
- K-Means算法是一種聚類分析(cluster analysis)的算法,其主要是來計算數據聚集的算法,主要通過不斷地取離種子點最近均值的算法。
- K-Means算法主要解決的問題如下圖所示。我們可以看到,在圖的左邊有一些點,我們用肉眼可以看出來有四個點羣,但是我們怎麼通過計算機程序找出這幾個點羣來呢?於是就出現了我們的K-Means算法
- 這個算法其實很簡單,如下圖所示:
從上圖中,我們可以看到,A,B,C,D,E是五個在圖中點。而灰色的點是我們的種子點,也就是我們用來找點羣的點。有兩個種子點,所以K=2。
- GIF動畫實例:
它們先分別連接離自己最近的點,然後移動到中間,再放棄不是離自己最近的點,重新連接離自己近的點。
然後,K-Means的算法如下:
- 隨機在圖中取K(這裏K=2)個種子點。
- 然後對圖中的所有點求到這K個種子點的距離,假如點Pi離種子點Si最近,那麼Pi屬於Si點羣。(上圖中,我們可以看到A,B屬於上面的種子點,C,D,E屬於下面中部的種子點)
- 接下來,我們要移動種子點到屬於他的“點羣”的中心。(見圖上的第三步)
- 然後重複第2)和第3)步,直到,種子點沒有移動(我們可以看到圖中的第四步上面的種子點聚合了A,B,C,下面的種子點聚合了D,E)。
這個算法很簡單,重點說一下“求點羣中心的算法”:歐氏距離(Euclidean Distance):差的平方和的平方根
K-Means主要最重大的缺陷——都和初始值有關:
-
K是事先給定的,這個K值的選定是非常難以估計的。很多時候,事先並不知道給定的數據集應該分成多少個類別才最合適。(ISODATA算法通過類的自動合併和分裂,得到較爲合理的類型數目K)
-
K-Means算法需要用初始隨機種子點來搞,這個隨機種子點太重要,不同的隨機種子點會有得到完全不同的結果。(K-Means++算法可以用來解決這個問題,其可以有效地選擇初始點)
總結:K-Means算法步驟:
- 從數據中選擇k個對象作爲初始聚類中心;
- 計算每個聚類對象到聚類中心的距離來劃分;
- 再次計算每個聚類中心
- 計算標準測度函數,直到達到最大迭代次數,則停止,否則,繼續操作。
- 確定最優的聚類中心
K-Means算法應用:
看到這裏,你會說,K-Means算法看來很簡單,而且好像就是在玩座標點,沒什麼真實用處。而且,這個算法缺陷很多,還不如人工呢。是的,前面的例子只是玩二維座標點,的確沒什麼意思。但是你想一下下面的幾個問題:
1)如果不是二維的,是多維的,如5維的,那麼,就只能用計算機來計算了。
2)二維座標點的X,Y 座標,其實是一種向量,是一種數學抽象。現實世界中很多屬性是可以抽象成向量的,比如,我們的年齡,我們的喜好,我們的商品,等等,能抽象成向量的目的就是可以讓計算機知道某兩個屬性間的距離。如:我們認爲,18歲的人離24歲的人的距離要比離12歲的距離要近,鞋子這個商品離衣服這個商品的距離要比電腦要近,等等。
二、實戰1
2.1、make_blobs隨機生成點
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# 無監督學習,只有分類
from sklearn.cluster import KMeans
import sklearn.datasets as datasets
#隨機生成點
X,y = datasets.make_blobs()
#畫散點圖
plt.scatter(X[:,0],X[:,1]) #如果加上 c = y ,那麼散點自帶顏色
2.2、聚類
kmeans = KMeans(3) #聚成3類
kmeans.fit(X)
y_ = kmeans.predict(X)
plt.scatter(X[:,0],X[:,1],c = y_) #畫散點圖
2.3、輪廓係數
- 當一個數據我們不知道聚成多少類合適時,我們可以通過輪廓係數的最大值來確定,輪廓係數越大,說明聚類越好。
from sklearn.metrics import silhouette_score
kmeans = KMeans(2) #聚成兩類
kmeans.fit(X)
y_ = kmeans.predict(X)
print (silhouette_score(X,y_)) # 查看輪廓係數
0.7182695065525002
三、實戰2
Kmeans將亞洲足球隊自動分組
3.1、導入數據
#導入庫
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
%matplotlib inline
#導入數據,並添加表頭
data = pd.read_csv('./Asiafootball.txt',header = None,names = ['國家','2006','2010','2007'])
data
國家 | 2006 | 2010 | 2007 | |
---|---|---|---|---|
0 | 中國 | 50 | 50 | 9 |
1 | 日本 | 28 | 9 | 4 |
2 | 韓國 | 17 | 15 | 3 |
3 | 伊朗 | 25 | 40 | 5 |
4 | 沙特 | 28 | 40 | 2 |
5 | 伊拉克 | 50 | 50 | 1 |
6 | 卡塔爾 | 50 | 40 | 9 |
7 | 阿聯酋 | 50 | 40 | 9 |
8 | 烏茲別克斯坦 | 40 | 40 | 5 |
9 | 泰國 | 50 | 50 | 9 |
10 | 越南 | 50 | 50 | 5 |
11 | 阿曼 | 50 | 50 | 9 |
12 | 巴林 | 40 | 40 | 9 |
13 | 朝鮮 | 40 | 32 | 17 |
14 | 印尼 | 50 | 50 | 9 |
3.2、切分數據
X = data.iloc[:,1:] #把國家切掉
X.head(3) #顯示前三行數據
2006 | 2010 | 2007 | |
---|---|---|---|
0 | 50 | 50 | 9 |
1 | 28 | 9 | 4 |
2 | 17 | 15 | 3 |
3.3、聚類
kmeans = KMeans(n_clusters=3)
kmeans.fit(X)
y_ = kmeans.predict(X)
y_
array([2, 1, 1, 0, 0, 2, 2, 2, 0, 2, 2, 2, 0, 0, 2])
結果分析: 現在已經上面的數據已經被分爲三類了,得到每個的索引。
3.4、查看所有的國家
c = data['國家'].values
c
array(['中國', '日本', '韓國', '伊朗', '沙特', '伊拉克', '卡塔爾', '阿聯酋', '烏茲別克斯坦', '泰國',
'越南', '阿曼', '巴林', '朝鮮', '印尼'], dtype=object)
3.5、聚類國家
- 根據國家的戰況,放戰力差不多的國家聚在一起。
for i in range(3):
print(c[np.argwhere(y_ == i).ravel()])
['伊朗' '沙特' '烏茲別克斯坦' '巴林' '朝鮮']
['日本' '韓國']
['中國' '伊拉克' '卡塔爾' '阿聯酋' '泰國' '越南' '阿曼' '印尼']