Pytorch學習(二十二)soft label的交叉熵loss的實現

總說

參考的鏈接:

  1. https://blog.csdn.net/tsyccnh
  2. https://www.zhihu.com/question/41252833/answer/140950659

先理解一下信息熵、交叉熵和相對熵

先找一下交叉熵的定義:
1)信息熵:編碼方案完美時,最短平均編碼長度的是多少。
2)交叉熵:編碼方案不一定完美時(由於對概率分佈的估計不一定正確),平均編碼長度的是多少。 平均編碼長度 = 最短平均編碼長度 + 一個增量
3)相對熵:編碼方案不一定完美時,平均編碼長度相對於最小值的增加值。(即上面那個增量)
(即,相對熵就是信息增益,就是KL散度

作者:張一山
鏈接:https://www.zhihu.com/question/41252833/answer/140950659
來源:知乎
著作權歸作者所有。商業轉載請聯繫作者獲得授權,非商業轉載請註明出處。

信息熵

對於(1)沒啥說的,就是如果某件事情概率已知,那麼直接 H(X)=i=1np(xi)log(p(xi))H(X)=-\sum_{i=1}^{n}p(x_i)log(p(x_i))
這個表示,對於事件X只有x1,,xi,,xnx_1, \cdot , x_i, \cdot, x_n種情況,而xix_i發生的概率是p(xi)p(x_i),那麼直接 plog(p)-plog(p) OK?稍微說一下原因,主要是,log(p(xi))log(p(x_i))代表就是xix_i事件的熵,也就是信息量。由於x1x_1x2x_2獨立同分布,兩個同時發生,那麼概率是p(x1)p(x2)p(x_1)p(x_2),信息量應該是累加的,也就是有H(p(x1)p(x2))=H(p(x1)+H(p(x2))H(p(x_1)p(x_2))=H(p(x_1)+H(p(x_2)), 所以H(x)H(x)是對數函數。底可以取 2,e,102, e, 10都行,問題不大。最後,pp一般是小數,所以既然是信息,那得正數才符合,所以加個負號在座的沒有意見吧~。。
總結:plog(p)-plog(p)就是熵。(當然實際是,累加,這樣寫只是爲了好記)

相對熵

相對熵就是 KL散度,你想想,經常情況下,我們無法得知pp的真實分佈,那麼我們預測一個qq,希望這個qq的分佈和pp能儘量接近,可以用KL散度作爲測度。想想挺有意思的,比如圖像之間的差異,可以用L2損失,PSNR,SSIM等等,這些都是一種測量方式,表示差異。那你描述概率分佈的差異呢?不能直接相減吧,可以用這個(當然還有很多其他描述方式,比如Wasserstein distance來替代KL散度,引申出了WGAN)。

其中KL散度定義如下:
DKL(pq)=i=1np(xi)log(p(xi)q(xi)D_{KL}(p||q)=\sum_{i=1}^{n}p(x_i)log(\frac{p(x_i)}{q(x_i)}
在分類中,P表示真實分佈,Q是預測的分佈。那你自然要Q的分佈儘量接近P的。訓練一次,得到如果用Q分佈來表示P,需要額外的信息是DKL(pq)D_{KL}(p||q),更新一次,讓這個增益儘量減少。所以分類就是讓q(x)q(x)儘量接近p(x)p(x)

交叉熵

用眼睛可以看出:
DKL(pq)=H(p(x))+[i=1np(xi)log(q(xi))]D_{KL}(p||q)=-H(p(x)) +[-\sum_{i=1}^{n}p(x_i)log(q(x_i))]
後面就是交叉熵。
因爲H(p(x))-H(p(x))是不變的,表示訓練集的自然規律吧(比如這樣的圖片, 它是貓的概率是p(x=貓)即p(x1)p(x_1),是狗的p(x=)p(x=狗)p(x2)p(x_2))。
再來看看交叉熵:

H(p,q)=i=1np(xi)log(q(xi))H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))
表示的是,用估計的概率qq來編碼,需要的編碼長度。

對比一下:熵和交叉熵的形式(只是爲了方便記憶):
H(X)=i=1np(xi)log(p(xi))H(X)=-\sum_{i=1}^{n}p(x_i)log(p(x_i))
H(p,q)=i=1np(xi)log(q(xi))H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))
**其實都是plog(p)-plog(p)**形式,當不知道pp的概率時,就用估計的qq來替代一下,塞進loglog中。

分類中的交叉熵

熟練寫出交叉熵後,我們來看看,沒啥毛病,q(xi)q(x_i)就是當輸入這張圖片II時,網絡的輸出的概率(經過softmax後)
H(p,q)=i=1np(xi)log(q(xi))H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))
如果這個圖片是貓,那麼p(I=)=p(x1)=1p(I=貓)=p(x_1)=1,其他的概率爲0.

那其實就很簡單了,所以普通的交叉熵的計算如下:
loss=i=1Kyilog(yi^)loss = -\sum_{i=1}^{K}y_{i}log(\hat{y_i})
表示,輸入II的時候,這張圖片與真實分佈的損失。(就是交叉熵,log種的那個每一類的概率用預測的就行。)
如果是普通的分類,每張圖片就一個損失值,就是 log()-log(預測到了正確的類別的概率), 因爲真實標籤一般採用hard label,就是p()=1p(該圖屬於正確類別)=1, p()=0p(該圖屬於其他類別)=0.
當然正規寫的時候:
loss=1Nj=1Ni=1nyjilog(yji^)loss = -\frac{1}{N}\sum_{j=1}^{N}\sum_{i=1}^{n}y_{ji}log(\hat{y_{ji}})
其中NN是batch的圖片數。

SoftCrossEntropy

其實就是按照公式來就行:
loss=1Nj=1Ni=1nyjilog(yji^)loss = -\frac{1}{N}\sum_{j=1}^{N}\sum_{i=1}^{n}y_{ji}log(\hat{y_{ji}})
只不過普通的看上去好像是每張圖片有KK個類別的值相加,實際上只有1個值。
如果是soft的話,就是真的是KK個值相加了。

於是,有了下面這個:

import torch
import torch.nn.functional as F


def SoftCrossEntropy(inputs, target, reduction='sum'):
    log_likelihood = -F.log_softmax(inputs, dim=1)
    batch = inputs.shape[0]
    if reduction == 'average':
        loss = torch.sum(torch.mul(log_likelihood, target)) / batch
    else:
        loss = torch.sum(torch.mul(log_likelihood, target))
    return loss

注意點: target是已經經過softmax歸一化後的值,即表示真實概率yjiy_ji。如果是2分類,則i=1i=2i=1或者i=2。 而inputs是網絡的直接輸出(卷積層或是fc的輸出,沒有經過softmax),所以log(q)-log(q)啊,所以這裏用-F.log_softmax。當然,最後plog(q)-plog(q),直接和target相乘就行。

普通的分類就這樣做,如果是SSD之類的,他其實是對每個feature的點(對應pixel level上的小框)進行分類。比如inputsM*C的,其中M=N*out_dim。也就是說,假設網絡輸出是out_dim維度(比如ssd,2分類,網絡輸出8000多個預選礦),N是batchsize,那麼直接前面兩個維度直接合並好不,每個點(N*out_dim這麼多個feature點(或者說是小框))都要進行分類。就相當於每個小框的分類一樣。

附加:二分類和多分類的區別

最簡單的分類是二分類,這裏說的二分類是指,是或者不是這個類。如果是二分類的話,其實最後一個神經元就行了,這時候就用 BCEWithLogitsLoss()或者nn.Sigmoid後面再加上BCELoss。這時候就沒必要用softmax進行歸一化了。如果是有A和B兩種類別,那麼最後還是要2個神經元。多分類還是秉承LogSoftmax()後面接NLLLoss()

>>> # 2D loss example (used, for example, with image inputs)
>>> N, C = 5, 4
>>> loss = nn.NLLLoss()
>>> # input is of size N x C x height x width
>>> data = torch.randn(N, 16, 10, 10)
>>> conv = nn.Conv2d(16, C, (3, 3))
>>> m = nn.LogSoftmax(dim=1)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
>>> output = loss(m(conv(data)), target)
>>> output.backward()

其中網絡的輸出的最大值(LogSoftMax之前就行),就是這個圖像的類別。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章