混淆矩陣含義及python代碼實現

一、分類評估指標中定義的一些符號含義

  1. TP:將正類預測爲正類數,真實爲0,預測也爲0
  2. FN:將正類預測爲負類數,真實爲0,預測爲1
  3. FP:將負類預測爲正類數, 真實爲1,預測爲0
  4. TN:將負類預測爲負類數,真實爲1,預測也爲1

 二、混淆矩陣定義及表示含義

混淆矩陣是機器學習中總結分類模型預測結果的情形分析表,以矩陣形式將數據集中的記錄按照真實的類別與分類模型預測的類別判斷兩個標準進行彙總。其中矩陣的行表示真實值,矩陣的列表示預測值。

二分類問題:

  混淆矩陣

     預測值

正(貓) 負(狗)

正(貓)      3     0
負(狗)       1      2

通過混淆矩陣我們可以輕鬆算的真實值貓的數量(行數量相加)爲3=3+0,分類得到貓的數量(列數量相加)爲4=3+1。真實狗的數量爲3=1+2,分類得到狗的數量爲2=0+2。同時,我們不難發現,對於二分類問題,矩陣中的4個元素剛好表示TP,TN,FP,TN這四個符號量 。

  混淆矩陣

      預測值

      正     負

     正       TP(a)     FN(b)
     負       FP(c)      TN(d)

則:精確率:Precision=a/(a+c)=TP/(TP+FP)

召回率:recall=a/(a+b)=TP/(TP+FN)

準確率:accuracy=(a+d)/(a+b+c+d)=(TP+TN)/(TP+FN+FP+TN)

多分類問題: 

   混淆

   矩陣

       預測值

類別1 類別2 類別3

類別1     a     b      c
類別2    d     e      f
類別3     g      h      i

矩陣行數據相加是真實值類別數,召回率_類別1=a/(a+b+c) 。列數據相加是分類後的類別數, 精確率_類別1=a/(a+d+g)。對角線相加是分類準確率,準確率accuracy=(a+e+i)/(a+b+c+d+e+f+g+h+i)

三、Python代碼實現混淆矩陣

sklearn.metrics.confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

y_true:是樣本真實分類結果,y_pred 是樣本預測分類結果 ,labels是所給出的類別,通過這個可對類別進行選擇 ,sample_weight 是樣本權重。我們用confusion_matrix生成矩陣數據,然後用seaborn的熱度圖繪製出混淆矩陣數據:

import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
sns.set()
y_true = ["cat", "dog", "cat", "cat", "dog", "rebit"]
y_pred = ["dog", "dog", "rebit", "cat", "dog", "cat"]
C2= confusion_matrix(y_true, y_pred, labels=["dog", "rebit", "cat"])
sns.heatmap(C2,annot=True)

參考文獻:

https://baijiahao.baidu.com/s?id=1619821729031070174&wfr=spider&for=pc

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章