precision_at_k

假設輸出節點是三個, 即對應的是一個三分類問題。假設有10條數據,輸出層的結果如下表所示

預測 0 1 2 實際標籤
1 3.2901842e-04 6.2009683e-05 9.9960905e-01 2
2 9.3300034e-05 6.3389372e-05 9.9984324e-01 2
3 1.8921893e-04 1.3000969e-04 9.9968076e-01] 0
4 7.2724116e-04 3.1397035e-04 9.9895883e-01 0
5 1.7774174e-03 5.5241789e-04 9.9767011e-01 0
6 1.5353756e-03 3.8222148e-04 9.9808240e-01 2
7 5.0105453e-03 8.7840611e-04 9.9411100e-01 2
8 1.1946440e-04 1.0893926e-04 9.9977165e-01 0
9 4.0907355e-04 2.3722800e-04 9.9935371e-01 0
10 5.0522230e-04 3.5394888e-04 9.9914086e-01 2

來看下k取不同值時,percision_at_k()的結果

tf.metrics.precision_at_k(labels=label_ids, predictions=logits,k=?) k=1 k =2 k=3
結果 0.50 0.45 0.33
計算 第1,2,6,7,10概率最大的標籤是2與真實標籤一致,所以tp=5.其他的幾條雖然最大標籤也爲2,但與真實標籤不一致, fp=5 1-10每一條的數據取概率最大的兩個標籤,正確的只有一個,另一個是不正確的。tp=10, fp=10。 同k=2, 每條數據的預測值都是正確的只有一個標籤, 錯誤的是兩個。故tp=10, fp=20
最終 pk=1=tptp+fp=55+5 p_{k=1}=\frac{tp}{tp+fp}=\frac5{5+5} pk=2=tptp+fp=1010+10p_{k=2}=\frac{tp}{tp+fp}=\frac{10}{10+10} pk=3=tptp+fp=1010+20p_{k=3}=\frac{tp}{tp+fp}=\frac{10}{10+20}

其實吧, 這個指標不太適合來衡量多分類結果(個人愚見, 歡迎指正)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章