目錄
本篇理論性不多,主要是部分總結及實戰內容。
一、EM算法的步驟
EM算法(英文叫做Expectation Maximization,最大期望算法)三個主要的步驟:
- 初始化參數
- 觀察預期
- 重新估計
二、EM算法的工作原理
EM算法一般用於聚類,也就是無監督模型裏面,因爲無監督學習沒有標籤,EM算法可以先給無監督學習估計一個隱狀態(即標籤),有了標籤,算法模型就可以轉換成有監督學習,這時就可以用極大似然估計法求解出模型最優參數。其中估計隱狀態流程應爲EM算法的E步,後面用極大似然估計爲M步。
下面介紹一下兩種不同類型的聚類算法:
- 硬聚類算法:如K-Means ,是通過距離來區分樣本之間的差別的,且每個樣本在計算的時候只能屬於一個分類。
- 軟聚類算法:如EM 聚類,它在求解的過程中,實際上每個樣本都有一定的概率和每個聚類相關。
EM 算法相當於一個框架,你可以採用不同的模型來進行聚類,比如 GMM(高斯混合模型),或者 HMM(隱馬爾科夫模型)來進行聚類。
- GMM 是通過概率密度來進行聚類,聚成的類符合高斯分佈(正態分佈)。
- HMM 用到了馬爾科夫過程,在這個過程中,我們通過狀態轉移矩陣來計算狀態轉移的概率。
三、在sklearn中創建GMM模型
本案例採用GMM高斯混合模型。因此將介紹下如何在sklearn中創建GMM聚類。
gmm = GaussianMixture(n_components=1, covariance_type='full', max_iter=100)
看一下這幾個參數:
1. n_components:即高斯混合模型的個數,也就是我們要聚類的個數,默認值爲 1。如果你不指定 n_components,最終的聚類結果都會爲同一個值。
2. covariance_type:代表協方差類型。一個高斯混合模型的分佈是由均值向量和協方差矩陣決定的,所以協方差的類型也代表了不同的高斯混合模型的特徵。協方差類型有 4 種取值:
- covariance_type=full,代表完全協方差,也就是元素都不爲 0;
- covariance_type=tied,代表相同的完全協方差;
- covariance_type=diag,代表對角協方差,也就是對角不爲 0,其餘爲 0;
- covariance_type=spherical,代表球面協方差,非對角爲 0,對角完全相同,呈現球面的特性。
3. max_iter:代表最大迭代次數,EM 算法是由 E 步和 M 步迭代求得最終的模型參數,這裏可以指定最大迭代次數,默認值爲 100。
創建完GMM聚類器之後,可以傳入數據讓它進行迭代擬合。我們使用 fit 函數,傳入樣本特徵矩陣,模型會自動生成聚類器,然後使用 prediction=gmm.predict(data) 來對數據進行聚類,傳入你想進行聚類的數據,可以得到聚類結果 prediction。
四、工作流程
我們使用王者榮耀英雄數據集來進行聚類,數據包括 69 名英雄的 23 個特徵屬性。這些屬性分別是,英雄,最大生命,生命成長,初始生命,最大法力,法力成長,初始法力,最高物攻,物攻成長,初始物攻,最大物防,物防成長,初始物防,最大每5秒回血,每5秒回血成長,初始每5秒回血,最大每5秒回藍,每5秒回藍成長,初始每5秒回藍,最大攻速,攻擊範圍,主要定位,次要定位。
王者榮耀英雄數據集鏈接:https://github.com/cystanford/EM_data
先劃分一下流程:
整個訓練過程基本上都會包括三個階段:
-
首先加載數據集;
-
在準備階段,我們需要對數據進行探索,包括採用數據可視化技術,讓我們對英雄屬性以及這些屬性之間的關係理解更加深刻,然後對數據質量進行評估,是否進行數據清洗,最後進行特徵選擇方便後續的聚類算法;
-
聚類階段:選擇適合的聚類模型,這裏我們採用 GMM 高斯混合模型進行聚類,並輸出聚類結果,對結果進行分析。
五、實戰環節
1. 導包
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.mixture import GaussianMixture
2. 加載數據
# 數據加載,避免中文亂碼問題
data_original = pd.read_csv('dataset/heros.csv', encoding = 'gb18030')
3. 數據可視化分析
data_original.head().append(data_original.tail()) # 顯示前5行和後5行數據
英雄 | 最大生命 | 生命成長 | 初始生命 | 最大法力 | 法力成長 | 初始法力 | 最高物攻 | 物攻成長 | 初始物攻 | ... | 最大每5秒回血 | 每5秒回血成長 | 初始每5秒回血 | 最大每5秒回藍 | 每5秒回藍成長 | 初始每5秒回藍 | 最大攻速 | 攻擊範圍 | 主要定位 | 次要定位 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 夏侯惇 | 7350 | 288.8 | 3307 | 1746 | 94 | 430 | 321 | 11.570 | 159 | ... | 98 | 3.357 | 51 | 37 | 1.571 | 15 | 28.00% | 近戰 | 坦克 | 戰士 |
1 | 鍾無豔 | 7000 | 275.0 | 3150 | 1760 | 95 | 430 | 318 | 11.000 | 164 | ... | 92 | 3.143 | 48 | 37 | 1.571 | 15 | 14.00% | 近戰 | 戰士 | 坦克 |
2 | 張飛 | 8341 | 329.4 | 3450 | 100 | 0 | 100 | 301 | 10.570 | 153 | ... | 115 | 4.143 | 57 | 5 | 0.000 | 5 | 14.00% | 近戰 | 坦克 | 輔助 |
3 | 牛魔 | 8476 | 352.8 | 3537 | 1926 | 104 | 470 | 273 | 8.357 | 156 | ... | 117 | 4.214 | 58 | 42 | 1.786 | 17 | 14.00% | 近戰 | 坦克 | 輔助 |
4 | 呂布 | 7344 | 270.0 | 3564 | 0 | 0 | 0 | 343 | 12.360 | 170 | ... | 97 | 3.071 | 54 | 0 | 0.000 | 0 | 14.00% | 近戰 | 戰士 | 坦克 |
64 | 阿軻 | 5968 | 192.8 | 3269 | 0 | 0 | 0 | 427 | 17.860 | 177 | ... | 81 | 2.214 | 50 | 0 | 0.000 | 0 | 28.00% | 近戰 | 刺客 | NaN |
65 | 娜可露露 | 6205 | 211.9 | 3239 | 1808 | 97 | 450 | 385 | 15.140 | 173 | ... | 79 | 2.286 | 47 | 38 | 1.571 | 16 | 14.00% | 近戰 | 刺客 | NaN |
66 | 蘭陵王 | 6232 | 210.0 | 3292 | 1822 | 98 | 450 | 388 | 15.500 | 171 | ... | 99 | 3.357 | 52 | 46 | 1.929 | 19 | 14.00% | 近戰 | 刺客 | NaN |
67 | 鎧 | 6700 | 237.5 | 3375 | 1784 | 96 | 440 | 328 | 10.860 | 176 | ... | 81 | 2.643 | 44 | 38 | 1.571 | 16 | 28.00% | 近戰 | 戰士 | 坦克 |
68 | 百里守約 | 5611 | 185.1 | 3019 | 1784 | 96 | 440 | 410 | 15.860 | 188 | ... | 68 | 2.071 | 39 | 38 | 1.571 | 16 | 28.00% | 遠程 | 射手 | 刺客 |
# 對英雄屬性之間的關係進行可視化分析
# 設置plt正確顯示中文
plt.rcParams['font.sans-serif']=['SimHei'] #用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus']=False #用來正常顯示負號
# 用熱力圖呈現特徵之間的相關性
corr = data.corr()
plt.figure(figsize=(14,14))
# annot=True顯示每個方格的數據
sns.heatmap(corr, annot=True)
plt.show()
我們將 18 個英雄屬性之間的關係用熱力圖呈現了出來,中間的數字代表兩個屬性之間的關係係數,最大值爲 1,代表完全正相關,關係係數越大代表相關性越大。
從圖中你能看出來“最大生命”“生命成長”和“初始生命”這三個屬性的相關性大,我們只需要保留一個屬性即可。同理我們也可以對其他相關性大的屬性進行篩選,保留一個。 這既是對原有屬性進行降維。
4. 特徵工程
# 相關性大的屬性保留一個,因此可以對屬性進行降維
features_remain = [u'最大生命', u'初始生命', u'最大法力', u'最高物攻', u'初始物攻', u'最大物防', u'初始物防',
u'最大每5秒回血', u'最大每5秒回藍', u'初始每5秒回藍', u'最大攻速', u'攻擊範圍']
data = data_original[features_remain]
data
最大生命 | 初始生命 | 最大法力 | 最高物攻 | 初始物攻 | 最大物防 | 初始物防 | 最大每5秒回血 | 最大每5秒回藍 | 初始每5秒回藍 | 最大攻速 | 攻擊範圍 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7350 | 3307 | 1746 | 321 | 159 | 397 | 101 | 98 | 37 | 15 | 28.00% | 近戰 |
1 | 7000 | 3150 | 1760 | 318 | 164 | 409 | 100 | 92 | 37 | 15 | 14.00% | 近戰 |
2 | 8341 | 3450 | 100 | 301 | 153 | 504 | 125 | 115 | 5 | 5 | 14.00% | 近戰 |
3 | 8476 | 3537 | 1926 | 273 | 156 | 394 | 109 | 117 | 42 | 17 | 14.00% | 近戰 |
4 | 7344 | 3564 | 0 | 343 | 170 | 390 | 99 | 97 | 0 | 0 | 14.00% | 近戰 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
64 | 5968 | 3269 | 0 | 427 | 177 | 349 | 89 | 81 | 0 | 0 | 28.00% | 近戰 |
65 | 6205 | 3239 | 1808 | 385 | 173 | 359 | 86 | 79 | 38 | 16 | 14.00% | 近戰 |
66 | 6232 | 3292 | 1822 | 388 | 171 | 342 | 85 | 99 | 46 | 19 | 14.00% | 近戰 |
67 | 6700 | 3375 | 1784 | 328 | 176 | 388 | 107 | 81 | 38 | 16 | 28.00% | 近戰 |
68 | 5611 | 3019 | 1784 | 410 | 188 | 329 | 94 | 68 | 38 | 16 | 28.00% | 遠程 |
5. 數據規範化
我們能看到“最大攻速”這個屬性值是百分數,不適合做矩陣運算,因此我們需要將百分數轉化爲小數。我們也看到“攻擊範圍”這個字段的取值爲遠程或者近戰,也不適合矩陣運算,我們將取值做個映射,用 1 代表遠程,0 代表近戰。然後採用 Z-Score 規範化,對特徵矩陣進行規範化。
data[u'最大攻速'] = data[u'最大攻速'].apply(lambda x: float(x.strip('%'))/100)
data[u'攻擊範圍'] = data[u'攻擊範圍'].map({'遠程':1,'近戰':0})
# 採用Z-Score規範化數據,保證每個特徵維度的數據均值爲0,方差爲1
ss = StandardScaler()
data = ss.fit_transform(data)
6. 建模併產生結果,寫入文件
# 構造GMM聚類
gmm = GaussianMixture(n_components=30, covariance_type='full')
gmm.fit(data)
# 訓練數據
prediction = gmm.predict(data)
print(prediction)
# 將分組結果輸出到CSV文件中
data_original.insert(0, '分組', prediction)
data_original.to_csv('./EM_data/heros_out.csv', index=False, sep=',')
[ 2 13 6 8 26 3 0 6 21 13 7 13 21 20 17 7 27 21 26 5 9 5 5 5
5 5 5 1 4 23 20 4 16 4 23 4 4 19 12 16 16 4 4 4 29 16 13 12
13 29 24 14 10 11 11 2 25 13 22 26 25 10 15 2 18 14 14 28 1]
我們採用了 GMM 高斯混合模型,並將結果輸出到 CSV 文件中。聚類個數爲 30。
7. 顯示聚類後的結果
data_group = pd.read_csv('./EM_data/heros_out.csv')
data_group.sort_values('分組')
分組 | 英雄 | 最大生命 | 生命成長 | 初始生命 | 最大法力 | 法力成長 | 初始法力 | 最高物攻 | 物攻成長 | ... | 最大每5秒回血 | 每5秒回血成長 | 初始每5秒回血 | 最大每5秒回藍 | 每5秒回藍成長 | 初始每5秒回藍 | 最大攻速 | 攻擊範圍 | 主要定位 | 次要定位 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
6 | 0 | 羋月 | 6164 | 281.5 | 3105 | 100 | 0 | 100 | 289 | 9.786 | ... | 77 | 2.357 | 44 | 0 | 0.000 | 0 | 0.00% | 遠程 | 法師 | 坦克 |
68 | 1 | 百里守約 | 5611 | 185.1 | 3019 | 1784 | 96 | 440 | 410 | 15.860 | ... | 68 | 2.071 | 39 | 38 | 1.571 | 16 | 28.00% | 遠程 | 射手 | 刺客 |
27 | 1 | 成吉思汗 | 5799 | 198.0 | 3027 | 1742 | 93 | 440 | 394 | 15.000 | ... | 66 | 2.071 | 37 | 36 | 1.500 | 15 | 42.00% | 遠程 | 射手 | NaN |
63 | 2 | 哪吒 | 7268 | 270.4 | 3483 | 1808 | 97 | 450 | 320 | 11.500 | ... | 98 | 3.214 | 53 | 38 | 1.571 | 16 | 28.00% | 近戰 | 戰士 | NaN |
55 | 2 | 楊戩 | 7420 | 291.5 | 3339 | 1694 | 91 | 420 | 325 | 11.360 | ... | 98 | 3.357 | 51 | 36 | 1.500 | 15 | 28.00% | 近戰 | 戰士 | NaN |
0 | 2 | 夏侯惇 | 7350 | 288.8 | 3307 | 1746 | 94 | 430 | 321 | 11.570 | ... | 98 | 3.357 | 51 | 37 | 1.571 | 15 | 28.00% | 近戰 | 坦克 | 戰士 |
5 | 3 | 亞瑟 | 8050 | 316.3 | 3622 | 0 | 0 | 0 | 346 | 13.000 | ... | 106 | 3.643 | 55 | 0 | 0.000 | 0 | 14.00% | 近戰 | 戰士 | 坦克 |
31 | 4 | 甄姬 | 5584 | 181.6 | 3041 | 2002 | 108 | 490 | 296 | 9.357 | ... | 71 | 2.000 | 43 | 44 | 1.857 | 18 | 14.00% | 遠程 | 法師 | NaN |
33 | 4 | 干將莫邪 | 5583 | 171.0 | 3189 | 1946 | 104 | 490 | 292 | 9.500 | ... | 71 | 1.857 | 45 | 41 | 1.714 | 17 | 14.00% | 遠程 | 法師 | NaN |
41 | 4 | 小喬 | 5916 | 202.0 | 3088 | 1988 | 107 | 490 | 263 | 7.857 | ... | 75 | 2.214 | 44 | 43 | 1.786 | 18 | 14.00% | 遠程 | 法師 | NaN |
第一列代表的是分組(簇),我們能看到百里守約和成吉思汗分到了一組,哪吒、楊戩和夏侯惇是一組,亞瑟自己是一組,甄姬、干將莫邪和小喬是一組。
聚類的特點是相同類別之間的屬性值相近,不同類別的屬性值差異大。
8. 聚類結果的評估
聚類和分類不一樣,聚類是無監督的學習方式,也就是我們沒有實際的結果可以進行比對,所以聚類的結果評估不像分類準確率一樣直觀,那麼有沒有聚類結果的評估方式呢?這裏我們可以採用 Calinski-Harabaz 指標,代碼如下:
from sklearn.metrics import calinski_harabaz_score
print(calinski_harabaz_score(data, prediction))
20.273576816244606
指標分數越高,代表聚類效果越好,也就是相同類中的差異性小,不同類之間的差異性大。當然具體聚類的結果含義,我們需要人工來分析,也就是當這些數據被分成不同的類別之後,具體每個類表代表的含義。