混淆矩陣(Confusion Matrix)是機器學習中用來總結分類模型預測結果的一個分析表,是模式識別領域中的一種常用的表達形式。它以矩陣的形式描繪樣本數據的真實屬性和分類預測結果類型之間的關係,是用來評價分類器性能的一種常用方法。
我們可以通過一個簡單的例子來直觀理解混淆矩陣。
通過分類模型我們得到的預測結果以及真實的屬性可以通過列表的形式展現,
y_pred=["ant", "ant", "cat", "cat", "ant", "cat"] #預測
y_true=["cat", "ant", "cat", "cat", "ant", "bird"] #真實
數軸的標籤表示真實屬性,而橫軸的標籤表示分類的預測結果。此矩陣的第一行第一列這個數字2表示ant被成功分類成爲ant的樣本數目,第三行第一列的數字1表示cat被分類成ant的樣本數目,諸如此類。
混淆矩陣的每一行數據之和代表該類別的真實的數目,每一列之和代表該類別的預測的數目,矩陣的對角線上的數值代表被正確預測的樣本數目。
那麼這個混淆矩陣是如何繪製的呢?
這裏給出兩種簡單的方法,一是使用seaborn的熱力圖來繪製,可以直接將混淆矩陣可視化;
C=confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
df=pd.DataFrame(C,index=["ant", "bird", "cat"],columns=["ant", "bird", "cat"])
sns.heatmap(df,annot=True)
另外一種是使用matplotlib的matshow來繪製。
plt.matshow(C, cmap=plt.cm.Greens)
plt.colorbar()
for i in range(len(C)):
for j in range(len(C)):
plt.annotate(C[i,j], xy=(i, j), horizontalalignment='center', verticalalignment='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
效果如下:
利用混淆矩陣的可視化,我們能夠分析類別誤判的結果,從而對機器學習的模型進行調整。