损失函数中的logits

PyTorch(tensorflow类似)的损失函数中,有一个(类)损失函数名字中带了with_logits. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间.

logit函数

其形式如下:

L(p)=p1pL(p)=\frac{p}{1-p}

该函数可以将输入范围在[0,1]之间的数值p映射到[,][-\infty,\infty].如果p=0.5,则函数值为0,p<0.5,则函数值为负值;如果p>0.5,则函数值为正值.

损失函数中的logits

而在损失函数中,如果其名称中带了with_logits则可以直接将之前网络的输出接到该损失函数中,不需要手动调用sigmoid(input)函数. 因为该损失函数中包含了诸如softmaxtsigmoid方法,会将输入其中的数值从[,][-\infty,\infty]映射到[0,1]之间.

官方示例

  • 名称中不带logits的损失函数:
>>> input = torch.randn((3, 2), requires_grad=True)
>>> target = torch.rand((3, 2), requires_grad=False)
>>> loss = F.binary_cross_entropy(F.sigmoid(input), target)
>>> loss.backward()

可以看到,用于二分类的binary_cross_entropy函数需要将输入先经过sigmoid处理,将其变换到[0,1]之间再计算loss,因为这里的target一般是[0,1]之间的数值.

  • 名字中带with_logits的损失函数:
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> loss = F.binary_cross_entropy_with_logits(input, target)
>>> loss.backward()

可以看到这里可以将前面网络的输出(即此处的input)之间送入loss函数中计算,无需手动调用sigmoid进行变换.

参考

TF里几种loss和注意事项
官方示例

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章