損失函數中的logits

PyTorch(tensorflow類似)的損失函數中,有一個(類)損失函數名字中帶了with_logits. 而這裏的logits指的是,該損失函數已經內部自帶了計算logit的操作,無需在傳入給這個loss函數之前手動使用sigmoid/softmax將之前網絡的輸入映射到[0,1]之間.

logit函數

其形式如下:

L(p)=p1pL(p)=\frac{p}{1-p}

該函數可以將輸入範圍在[0,1]之間的數值p映射到[,][-\infty,\infty].如果p=0.5,則函數值爲0,p<0.5,則函數值爲負值;如果p>0.5,則函數值爲正值.

損失函數中的logits

而在損失函數中,如果其名稱中帶了with_logits則可以直接將之前網絡的輸出接到該損失函數中,不需要手動調用sigmoid(input)函數. 因爲該損失函數中包含了諸如softmaxtsigmoid方法,會將輸入其中的數值從[,][-\infty,\infty]映射到[0,1]之間.

官方示例

  • 名稱中不帶logits的損失函數:
>>> input = torch.randn((3, 2), requires_grad=True)
>>> target = torch.rand((3, 2), requires_grad=False)
>>> loss = F.binary_cross_entropy(F.sigmoid(input), target)
>>> loss.backward()

可以看到,用於二分類的binary_cross_entropy函數需要將輸入先經過sigmoid處理,將其變換到[0,1]之間再計算loss,因爲這裏的target一般是[0,1]之間的數值.

  • 名字中帶with_logits的損失函數:
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> loss = F.binary_cross_entropy_with_logits(input, target)
>>> loss.backward()

可以看到這裏可以將前面網絡的輸出(即此處的input)之間送入loss函數中計算,無需手動調用sigmoid進行變換.

參考

TF裏幾種loss和注意事項
官方示例

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章