博客:
blog.shinelee.me |
博客園 |
CSDN
交叉熵損失與均方誤差損失
常規分類網絡最後的softmax層如下圖所示,傳統機器學習方法以此類比,
一共有K類,令網絡的輸出爲[y^1,…,y^K],對應每個類別的概率,令label爲 [y1,…,yK]。對某個屬於p類的樣本,其label中yp=1,y1,…,yp−1,yp+1,…,yK均爲0。
對這個樣本,交叉熵(cross entropy)損失爲
L=−(y1logy^1+⋯+yKlogy^K)=−yplogy^p=−logy^p
**均方誤差損失(mean squared error,MSE)**爲
L=(y1−y^1)2+⋯+(yK−y^K)2=(1−y^p)2+(y^12+⋯+y^p−12+y^p+12+⋯+y^K2)
則m個樣本的損失爲
ℓ=m1i=1∑mLi
對比交叉熵損失與均方誤差損失,只看單個樣本的損失即可,下面從兩個角度進行分析。
損失函數角度
損失函數是網絡學習的指揮棒,它引導着網絡學習的方向——能讓損失函數變小的參數就是好參數。
所以,損失函數的選擇和設計要能表達你希望模型具有的性質與傾向。
對比交叉熵和均方誤差損失,可以發現,兩者均在y^=y=1時取得最小值0,但在實踐中y^p只會趨近於1而不是恰好等於1,在y^p<1的情況下,
- 交叉熵只與label類別有關,y^p越趨近於1越好
- 均方誤差不僅與y^p有關,還與其他項有關,它希望y^1,…,y^p−1,y^p+1,…,y^K越平均越好,即在K−11−y^p時取得最小值
分類問題中,對於類別之間的相關性,我們缺乏先驗。
雖然我們知道,與“狗”相比,“貓”和“老虎”之間的相似度更高,但是這種關係在樣本標記之初是難以量化的,所以label都是one hot。
在這個前提下,均方誤差損失可能會給出錯誤的指示,比如貓、老虎、狗的3分類問題,label爲[1,0,0],在均方誤差看來,預測爲[0.8,0.1,0.1]要比[0.8,0.15,0.05]要好,即認爲平均總比有傾向性要好,但這有悖我們的常識。
而對交叉熵損失,既然類別間複雜的相似度矩陣是難以量化的,索性只能關注樣本所屬的類別,只要y^p越接近於1就好,這顯示是更合理的。
softmax反向傳播角度
softmax的作用是將(−∞,+∞)的幾個實數映射到(0,1)之間且之和爲1,以獲得某種概率解釋。
令softmax函數的輸入爲z,輸出爲y^,對結點p有,
y^p=∑k=1Kezkezp
y^p不僅與zp有關,還與{zk∣k=p}有關,這裏僅看$z_p $,則有
∂zp∂y^p=y^p(1−y^p)
y^p爲正確分類的概率,爲0時表示分類完全錯誤,越接近於1表示越正確。根據鏈式法則,按理來講,對與zp相連的權重,損失函數的偏導會含有y^p(1−y^p)這一因子項,y^p=0時分類錯誤,但偏導爲0,權重不會更新,這顯然不對——分類越錯誤越需要對權重進行更新。
對交叉熵損失,
∂y^p∂L=−y^p1
則有
∂z^p∂L=∂y^p∂L⋅∂zp∂y^p=y^p−1
恰好將y^p(1−y^p)中的y^p消掉,避免了上述情形的發生,且y^p越接近於1,偏導越接近於0,即分類越正確越不需要更新權重,這與我們的期望相符。
而對均方誤差損失,
∂y^p∂L=−2(1−y^p)=2(y^p−1)
則有,
∂z^p∂L=∂y^p∂L⋅∂zp∂y^p=−2y^p(1−y^p)2
顯然,仍會發生上面所說的情況——y^p=0,分類錯誤,但不更新權重。
綜上,對分類問題而言,無論從損失函數角度還是softmax反向傳播角度,交叉熵都比均方誤差要好。
參考