Pytorch学习(二十二)soft label的交叉熵loss的实现

总说

参考的链接:

  1. https://blog.csdn.net/tsyccnh
  2. https://www.zhihu.com/question/41252833/answer/140950659

先理解一下信息熵、交叉熵和相对熵

先找一下交叉熵的定义:
1)信息熵:编码方案完美时,最短平均编码长度的是多少。
2)交叉熵:编码方案不一定完美时(由于对概率分布的估计不一定正确),平均编码长度的是多少。 平均编码长度 = 最短平均编码长度 + 一个增量
3)相对熵:编码方案不一定完美时,平均编码长度相对于最小值的增加值。(即上面那个增量)
(即,相对熵就是信息增益,就是KL散度

作者:张一山
链接:https://www.zhihu.com/question/41252833/answer/140950659
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

信息熵

对于(1)没啥说的,就是如果某件事情概率已知,那么直接 H(X)=i=1np(xi)log(p(xi))H(X)=-\sum_{i=1}^{n}p(x_i)log(p(x_i))
这个表示,对于事件X只有x1,,xi,,xnx_1, \cdot , x_i, \cdot, x_n种情况,而xix_i发生的概率是p(xi)p(x_i),那么直接 plog(p)-plog(p) OK?稍微说一下原因,主要是,log(p(xi))log(p(x_i))代表就是xix_i事件的熵,也就是信息量。由于x1x_1x2x_2独立同分布,两个同时发生,那么概率是p(x1)p(x2)p(x_1)p(x_2),信息量应该是累加的,也就是有H(p(x1)p(x2))=H(p(x1)+H(p(x2))H(p(x_1)p(x_2))=H(p(x_1)+H(p(x_2)), 所以H(x)H(x)是对数函数。底可以取 2,e,102, e, 10都行,问题不大。最后,pp一般是小数,所以既然是信息,那得正数才符合,所以加个负号在座的没有意见吧~。。
总结:plog(p)-plog(p)就是熵。(当然实际是,累加,这样写只是为了好记)

相对熵

相对熵就是 KL散度,你想想,经常情况下,我们无法得知pp的真实分布,那么我们预测一个qq,希望这个qq的分布和pp能尽量接近,可以用KL散度作为测度。想想挺有意思的,比如图像之间的差异,可以用L2损失,PSNR,SSIM等等,这些都是一种测量方式,表示差异。那你描述概率分布的差异呢?不能直接相减吧,可以用这个(当然还有很多其他描述方式,比如Wasserstein distance来替代KL散度,引申出了WGAN)。

其中KL散度定义如下:
DKL(pq)=i=1np(xi)log(p(xi)q(xi)D_{KL}(p||q)=\sum_{i=1}^{n}p(x_i)log(\frac{p(x_i)}{q(x_i)}
在分类中,P表示真实分布,Q是预测的分布。那你自然要Q的分布尽量接近P的。训练一次,得到如果用Q分布来表示P,需要额外的信息是DKL(pq)D_{KL}(p||q),更新一次,让这个增益尽量减少。所以分类就是让q(x)q(x)尽量接近p(x)p(x)

交叉熵

用眼睛可以看出:
DKL(pq)=H(p(x))+[i=1np(xi)log(q(xi))]D_{KL}(p||q)=-H(p(x)) +[-\sum_{i=1}^{n}p(x_i)log(q(x_i))]
后面就是交叉熵。
因为H(p(x))-H(p(x))是不变的,表示训练集的自然规律吧(比如这样的图片, 它是猫的概率是p(x=猫)即p(x1)p(x_1),是狗的p(x=)p(x=狗)p(x2)p(x_2))。
再来看看交叉熵:

H(p,q)=i=1np(xi)log(q(xi))H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))
表示的是,用估计的概率qq来编码,需要的编码长度。

对比一下:熵和交叉熵的形式(只是为了方便记忆):
H(X)=i=1np(xi)log(p(xi))H(X)=-\sum_{i=1}^{n}p(x_i)log(p(x_i))
H(p,q)=i=1np(xi)log(q(xi))H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))
**其实都是plog(p)-plog(p)**形式,当不知道pp的概率时,就用估计的qq来替代一下,塞进loglog中。

分类中的交叉熵

熟练写出交叉熵后,我们来看看,没啥毛病,q(xi)q(x_i)就是当输入这张图片II时,网络的输出的概率(经过softmax后)
H(p,q)=i=1np(xi)log(q(xi))H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))
如果这个图片是猫,那么p(I=)=p(x1)=1p(I=猫)=p(x_1)=1,其他的概率为0.

那其实就很简单了,所以普通的交叉熵的计算如下:
loss=i=1Kyilog(yi^)loss = -\sum_{i=1}^{K}y_{i}log(\hat{y_i})
表示,输入II的时候,这张图片与真实分布的损失。(就是交叉熵,log种的那个每一类的概率用预测的就行。)
如果是普通的分类,每张图片就一个损失值,就是 log()-log(预测到了正确的类别的概率), 因为真实标签一般采用hard label,就是p()=1p(该图属于正确类别)=1, p()=0p(该图属于其他类别)=0.
当然正规写的时候:
loss=1Nj=1Ni=1nyjilog(yji^)loss = -\frac{1}{N}\sum_{j=1}^{N}\sum_{i=1}^{n}y_{ji}log(\hat{y_{ji}})
其中NN是batch的图片数。

SoftCrossEntropy

其实就是按照公式来就行:
loss=1Nj=1Ni=1nyjilog(yji^)loss = -\frac{1}{N}\sum_{j=1}^{N}\sum_{i=1}^{n}y_{ji}log(\hat{y_{ji}})
只不过普通的看上去好像是每张图片有KK个类别的值相加,实际上只有1个值。
如果是soft的话,就是真的是KK个值相加了。

于是,有了下面这个:

import torch
import torch.nn.functional as F


def SoftCrossEntropy(inputs, target, reduction='sum'):
    log_likelihood = -F.log_softmax(inputs, dim=1)
    batch = inputs.shape[0]
    if reduction == 'average':
        loss = torch.sum(torch.mul(log_likelihood, target)) / batch
    else:
        loss = torch.sum(torch.mul(log_likelihood, target))
    return loss

注意点: target是已经经过softmax归一化后的值,即表示真实概率yjiy_ji。如果是2分类,则i=1i=2i=1或者i=2。 而inputs是网络的直接输出(卷积层或是fc的输出,没有经过softmax),所以log(q)-log(q)啊,所以这里用-F.log_softmax。当然,最后plog(q)-plog(q),直接和target相乘就行。

普通的分类就这样做,如果是SSD之类的,他其实是对每个feature的点(对应pixel level上的小框)进行分类。比如inputsM*C的,其中M=N*out_dim。也就是说,假设网络输出是out_dim维度(比如ssd,2分类,网络输出8000多个预选矿),N是batchsize,那么直接前面两个维度直接合并好不,每个点(N*out_dim这么多个feature点(或者说是小框))都要进行分类。就相当于每个小框的分类一样。

附加:二分类和多分类的区别

最简单的分类是二分类,这里说的二分类是指,是或者不是这个类。如果是二分类的话,其实最后一个神经元就行了,这时候就用 BCEWithLogitsLoss()或者nn.Sigmoid后面再加上BCELoss。这时候就没必要用softmax进行归一化了。如果是有A和B两种类别,那么最后还是要2个神经元。多分类还是秉承LogSoftmax()后面接NLLLoss()

>>> # 2D loss example (used, for example, with image inputs)
>>> N, C = 5, 4
>>> loss = nn.NLLLoss()
>>> # input is of size N x C x height x width
>>> data = torch.randn(N, 16, 10, 10)
>>> conv = nn.Conv2d(16, C, (3, 3))
>>> m = nn.LogSoftmax(dim=1)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
>>> output = loss(m(conv(data)), target)
>>> output.backward()

其中网络的输出的最大值(LogSoftMax之前就行),就是这个图像的类别。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章