歡迎來到theFlyer的博客—希望你有不一樣的感悟
前言:交叉熵損失函數。
1. 損失函數
機器學習算法都或多或少的依賴於對目標函數最大化或者最小化的過程,常常把最小化的函數稱爲損失函數,它主要用於衡量機器學習模型的預測能力。
損失函數可以看出模型的優劣,提供了優化的方向,但是沒有任何一種損失函數適用於所有的模型。損失函數的選取依賴於參數的數量、異常值、機器學習算法、梯度下降的效率、導數求取的難易和預測的置信度等若干方面。
2. 交叉熵
對數損失Log Loss ,也被稱爲交叉熵損失Cross-entropy Loss,是定義在概率分佈的基礎上的。它通常用於多項式(multinomia)logistic regression 和神經網絡,還有在期望極大化算法(expectation-maximization)的一些變體中。
對數損失用來度量分類器的預測輸出的概率分佈(predict_proba)和真實分佈的差異,而不是去比較離散的類標籤是否相同。
2.1任務爲二分類時
在二分類的時候,真實標籤集合爲:Y∈{0,1}, 而分類器預測得到的概率分佈:P = Pr(y=1)
那麼,每一個樣本的對數損失就是在給定真實樣本標籤的條件下,分類器的負對數似然函數,如下所示:
當某個樣本的真實標籤y=1時, ,分類器的預測概率p=Pr(y=1)的概率越小,則分類損失就越大;反之,分類器的預測概率p=Pr(y=1)的概率越大,則分類損失就越小。
對於真實標籤y=0, ,分類器的預測概率p=Pr(y=1)的概率越大,則損失越大。
例:預測爲貓的p=Pr(y=1)概率是0.8,真實標籤y=1;預測不是貓的1-p=Pr(y=0)概率是0.2,真實標籤爲0。
* | 是貓 | 不是貓 |
---|---|---|
標籤 | 1 | 0 |
預測 | 0.8 | 0.2 |
此時損失爲
2.2任務爲多元分類時
在多元分類的時候,假定有k個類,則類標籤集合就是labels=(1,2,3,…,k).如果第i個樣本的類標籤是k的話,就記爲 。採用one-hot記法。每個樣本的真實標籤就是一個one-hot向量,其中只有一個位置記爲1。
例:設共有5類,label =3時,one-hot形式如下
標籤 | one-hot |
---|---|
3 | 00100 |
N個樣本的真實類標籤就是一個N行K列的矩陣:Y
Y | class 0 | class1 | class1 |
---|---|---|---|
sample1 | 0 | 1 | 0 |
sample2 | 1 | 0 | 0 |
sample3 | 0 | 1 | 0 |
sample4 | 0 | 0 | 1 |
sample5 | 1 | 0 | 0 |
分類器對N個樣本的每一個樣本都會預測出它屬於每個類的概率,這樣的概率矩陣P就是N行K列的。
P | class 0 | class1 | class1 |
---|---|---|---|
sample1 | 0.2 | 0.7 | 0.1 |
sample2 | 0.5 | 0.2 | 0.3 |
sample3 | 0.3 | 0.4 | 0.3 |
sample4 | 0.2 | 0.3 | 0.5 |
sample5 | 0.3 | 0.3 | 0.4 |
整個樣本集合上分類器的對數損失就可以如下定義:
此時損失爲
2.3任務爲多標籤分類時
多標籤是在一種圖片有多個類別時,比如一張圖片同時有貓狗。
* | 貓 | 狗 | 兔 |
---|---|---|---|
標籤 | 1 | 1 | 0 |
預測 | 0.8 | 0.7 | 0.1 |
與之前不一樣的是,預測不再通過softmax計算,而是採用sigmoid把輸出限制到(0,1)。正因此預測值得加和不再是1。這裏交叉熵單獨對每一個類別計算,每一個類別有兩種可能的類別,即屬於這個類的概率或不屬於這個類的概率。
例:單張圖片損失計算可以爲
各類損失計算如下
對於整體損失可以用下式:
後記
人生如苦旅,我亦是行人。
個人公衆號