PyTorch(tensorflow類似)的損失函數中,有一個(類)損失函數名字中帶了with_logits
. 而這裏的logits指的是,該損失函數已經內部自帶了計算logit的操作,無需在傳入給這個loss函數之前手動使用sigmoid/softmax
將之前網絡的輸入映射到[0,1]之間.
logit函數
其形式如下:
該函數可以將輸入範圍在[0,1]之間的數值p映射到.如果p=0.5,則函數值爲0,p<0.5,則函數值爲負值;如果p>0.5,則函數值爲正值.
損失函數中的logits
而在損失函數中,如果其名稱中帶了with_logits
則可以直接將之前網絡的輸出接到該損失函數中,不需要手動調用sigmoid(input)
函數. 因爲該損失函數中包含了諸如softmaxt
或sigmoid
方法,會將輸入其中的數值從映射到[0,1]之間.
官方示例
- 名稱中不帶
logits
的損失函數:
>>> input = torch.randn((3, 2), requires_grad=True)
>>> target = torch.rand((3, 2), requires_grad=False)
>>> loss = F.binary_cross_entropy(F.sigmoid(input), target)
>>> loss.backward()
可以看到,用於二分類的binary_cross_entropy
函數需要將輸入先經過sigmoid處理
,將其變換到[0,1]之間再計算loss,因爲這裏的target一般是[0,1]之間的數值.
- 名字中帶
with_logits
的損失函數:
>>> input = torch.randn(3, requires_grad=True)
>>> target = torch.empty(3).random_(2)
>>> loss = F.binary_cross_entropy_with_logits(input, target)
>>> loss.backward()
可以看到這裏可以將前面網絡的輸出(即此處的input)之間送入loss函數中計算,無需手動調用sigmoid
進行變換.