Pytorch 加权交叉熵实现分析

此文不涉及公式。

假设模型对输入图像数据(1(batch size) * 1 (channels) * 3 (height) * 3 (width))的分割输出结果为1 * 2 (只有两类,前景,背景)* 3 (height)* 3 (width),ground truth为 target ( 1 (batch size) * 3 * 3)。

 假设模型的分割结果为input, ground truth 为target。

根据文献[1],计算交叉熵损失有两中方式,一种是用F.nll_loss();一种是用F.cross_entropy(input, target)。

(1)利用F.cross_entropy()计算未加权的结果:

(2)利用F.cross_entropy()中的weight参数计算加权结果:

对比以上两个结果可以发现:加权后的交叉熵损失比未加权的交叉熵损失小。(假设加权后的交叉熵损失比未加权的交叉熵损失大,那pytorch的加权交叉熵实现是错误的?还需要找一些关于加权加权交叉熵的数学公式)继续分析,pytorch是如何实现加权交叉熵?

(3)对F.cross_entropy()中的reduce参数设置为False(关于cross_entropy中参数详细介绍,请参考文献[2])。

分析发现,(3)的结果与(1)的结果相同,则reduce=True的作用就是对batch_size * height * width 个像素点的交叉熵损失的均值。

(4)如果按照(3)的理解,对加权的F.cross_entropy的reduce的参数设置False,结果:

分析发现,该结果与(2)结果不一致。那么这个加权是怎么实现的?根据文献【3】, 首先,对batch_size * height * width 个像素点的加权交叉熵损失进行求和得到sum之后;然后,计算出batch_size * height * width 个像素点对应类别权重之和weight_sum,例如:

target中有6个类别为1的像素点,3个类别为0的像素点, 这些像素点对应类别权重之和为:6 * 10 + 3 * 1 = 63。

重新计算损失值:

个人理解:F.cross_entropy()中的weight参数作用:将每个类别的像素点的数量扩大weight倍。

参考文献:

【1】https://blog.csdn.net/qq_22210253/article/details/85229988

【2】https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss

【3】https://discuss.pytorch.org/t/how-to-use-the-weight-parameter-for-f-cross-entropy-correctly/17786

 代码:

  https://colab.research.google.com/drive/1VnZ53pwoLB3S0rGb7N-crgErLlxIR_VS#scrollTo=aCvteE5ZAx1Q

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章