直觀理解爲什麼分類問題用交叉熵損失而不用均方誤差損失?


博客:blog.shinelee.me | 博客園 | CSDN

交叉熵損失與均方誤差損失

常規分類網絡最後的softmax層如下圖所示,傳統機器學習方法以此類比,

https://stats.stackexchange.com/questions/273465/neural-network-softmax-activation

一共有KK類,令網絡的輸出爲[y^1,,y^K][\hat{y}_1,\dots, \hat{y}_K],對應每個類別的概率,令label爲 [y1,,yK][y_1, \dots, y_K]。對某個屬於pp類的樣本,其label中yp=1y_p=1y1,,yp1,yp+1,,yKy_1, \dots, y_{p-1}, y_{p+1}, \dots, y_K均爲0。

對這個樣本,交叉熵(cross entropy)損失
L=(y1logy^1++yKlogy^K)=yplogy^p=logy^p \begin{aligned}L &= - (y_1 \log \hat{y}_1 + \dots + y_K \log \hat{y}_K) \\&= -y_p \log \hat{y}_p \\ &= - \log \hat{y}_p\end{aligned}
**均方誤差損失(mean squared error,MSE)**爲
L=(y1y^1)2++(yKy^K)2=(1y^p)2+(y^12++y^p12+y^p+12++y^K2) \begin{aligned}L &= (y_1 - \hat{y}_1)^2 + \dots + (y_K - \hat{y}_K)^2 \\&= (1 - \hat{y}_p)^2 + (\hat{y}_1^2 + \dots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \dots + \hat{y}_K^2)\end{aligned}
mm個樣本的損失爲
=1mi=1mLi \ell = \frac{1}{m} \sum_{i=1}^m L_i
對比交叉熵損失與均方誤差損失,只看單個樣本的損失即可,下面從兩個角度進行分析。

損失函數角度

損失函數是網絡學習的指揮棒,它引導着網絡學習的方向——能讓損失函數變小的參數就是好參數。

所以,損失函數的選擇和設計要能表達你希望模型具有的性質與傾向。

對比交叉熵和均方誤差損失,可以發現,兩者均在y^=y=1\hat{y} = y = 1時取得最小值0,但在實踐中y^p\hat{y}_p只會趨近於1而不是恰好等於1,在y^p<1\hat{y}_p < 1的情況下,

  • 交叉熵只與label類別有關,y^p\hat{y}_p越趨近於1越好
  • 均方誤差不僅與y^p\hat{y}_p有關,還與其他項有關,它希望y^1,,y^p1,y^p+1,,y^K\hat{y}_1, \dots, \hat{y}_{p-1}, \hat{y}_{p+1}, \dots, \hat{y}_K越平均越好,即在1y^pK1\frac{1-\hat{y}_p}{K-1}時取得最小值

分類問題中,對於類別之間的相關性,我們缺乏先驗。

雖然我們知道,與“狗”相比,“貓”和“老虎”之間的相似度更高,但是這種關係在樣本標記之初是難以量化的,所以label都是one hot。

在這個前提下,均方誤差損失可能會給出錯誤的指示,比如貓、老虎、狗的3分類問題,label爲[1,0,0][1, 0, 0],在均方誤差看來,預測爲[0.8,0.1,0.1][0.8, 0.1, 0.1]要比[0.8,0.15,0.05][0.8, 0.15, 0.05]要好,即認爲平均總比有傾向性要好,但這有悖我們的常識

對交叉熵損失,既然類別間複雜的相似度矩陣是難以量化的,索性只能關注樣本所屬的類別,只要y^p\hat{y}_p越接近於1就好,這顯示是更合理的。

softmax反向傳播角度

softmax的作用是將(,+)(-\infty, +\infty)的幾個實數映射到(0,1)(0,1)之間且之和爲1,以獲得某種概率解釋。

令softmax函數的輸入爲zz,輸出爲y^\hat{y},對結點pp有,
y^p=ezpk=1Kezk \hat{y}_p = \frac{e^{z_p}}{\sum_{k=1}^K e^{z_k}}
y^p\hat{y}_p不僅與zpz_p有關,還與{zkkp}\{z_k | k\neq p\}有關,這裏僅看$z_p $,則有
y^pzp=y^p(1y^p) \frac{\partial \hat{y}_p}{\partial z_p} = \hat{y}_p(1-\hat{y}_p)
y^p\hat{y}_p爲正確分類的概率,爲0時表示分類完全錯誤,越接近於1表示越正確。根據鏈式法則,按理來講,對與zpz_p相連的權重,損失函數的偏導會含有y^p(1y^p)\hat{y}_p(1-\hat{y}_p)這一因子項,y^p=0\hat{y}_p = 0分類錯誤,但偏導爲0,權重不會更新,這顯然不對——分類越錯誤越需要對權重進行更新

交叉熵損失
Ly^p=1y^p \frac{\partial L}{\partial \hat{y}_p} = -\frac{1}{\hat{y}_p}
則有
Lz^p=Ly^py^pzp=y^p1 \frac{\partial L}{\partial \hat{z}_p} = \frac{\partial L}{\partial \hat{y}_p} \cdot \frac{\partial \hat{y}_p}{\partial z_p} = \hat{y}_p - 1
恰好將y^p(1y^p)\hat{y}_p(1-\hat{y}_p)中的y^p\hat{y}_p消掉,避免了上述情形的發生,且y^p\hat{y}_p越接近於1,偏導越接近於0,即分類越正確越不需要更新權重,這與我們的期望相符。

而對均方誤差損失
Ly^p=2(1y^p)=2(y^p1) \frac{\partial L}{\partial \hat{y}_p} = -2(1-\hat{y}_p)=2(\hat{y}_p - 1)
則有,
Lz^p=Ly^py^pzp=2y^p(1y^p)2 \frac{\partial L}{\partial \hat{z}_p} = \frac{\partial L}{\partial \hat{y}_p} \cdot \frac{\partial \hat{y}_p}{\partial z_p} = -2 \hat{y}_p (1 - \hat{y}_p)^2
顯然,仍會發生上面所說的情況——y^p=0\hat{y}_p = 0分類錯誤,但不更新權重

綜上,對分類問題而言,無論從損失函數角度還是softmax反向傳播角度,交叉熵都比均方誤差要好。

參考

發佈了55 篇原創文章 · 獲贊 77 · 訪問量 9萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章