Understanding PyTorch's CrossEntropyLoss() with worked examples, and computing the loss with one-hot encoded labels

When using the PyTorch framework for deep learning tasks, especially classification, you will often write something like the following:

import torch.nn as nn
criterion = nn.CrossEntropyLoss().cuda()
loss = criterion(output, target)

That is, torch.nn.CrossEntropyLoss() is used as the loss function.

So what exactly does nn.CrossEntropyLoss() do internally?

nn.CrossEntropyLoss() is a class provided by torch.nn, and it corresponds to cross_entropy in torch.nn.functional.
Moreover, nn.CrossEntropyLoss() is the combination of nn.LogSoftmax() and nn.NLLLoss() (the two are merged into a single class).

nn.LogSoftmax()

It is defined as:

\text{LogSoftmax}(x_i) = \log\left(\frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)

From the formula, it is simply softmax followed by log.
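A quick check of this (a minimal sketch, not from the original post): nn.LogSoftmax(dim=1) matches taking softmax explicitly and then the log.

import torch
import torch.nn as nn

x = torch.randn(2, 5)                        # a batch of 2 samples with 5 class scores each
log_sm = nn.LogSoftmax(dim=1)(x)             # softmax followed by log, in one operation
manual = torch.log(torch.softmax(x, dim=1))  # the same computation done in two steps
print(torch.allclose(log_sm, manual))        # True (LogSoftmax is also more numerically stable)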

nn.NLLLoss()

It is defined as:

\text{loss}(x, \text{class}) = -x[\text{class}]

The target this loss expects is a class index (0 to N-1, where N = number of classes).

Example 1:

import torch
import torch.nn as nn
from torch import autograd

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size nBatch x nClasses = 3 x 5
input = autograd.Variable(torch.randn(3, 5), requires_grad=True)
# each element in target has to have 0 <= value < nClasses
target = autograd.Variable(torch.LongTensor([1, 0, 4]))
output = loss(m(input), target)

As you can see, the target passed to nn.NLLLoss is a class index, not a one-hot encoding. Keep this in mind!
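To make the "index, not one-hot" point concrete, here is a small sketch (not from the original post): NLLLoss simply gathers the log-probability at each sample's target index, negates it, and averages over the batch.

import torch
import torch.nn as nn

log_prob = nn.LogSoftmax(dim=1)(torch.randn(3, 5))
target = torch.tensor([1, 0, 4])             # class indices, not one-hot vectors

nll = nn.NLLLoss()(log_prob, target)
# pick log_prob[n, target[n]] for each sample n, negate, and take the mean
manual = -log_prob[torch.arange(3), target].mean()
print(torch.allclose(nll, manual))           # True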

nn.CrossEntropyLoss()

It is defined as:

\text{loss}(x, \text{class}) = -\log\left(\frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])}\right) = -x[\text{class}] + \log\left(\sum_j \exp(x[j])\right)

Looking closely at the formula, it is exactly nn.LogSoftmax() + nn.NLLLoss().
When called, the arguments are:

  • input: the model output with a score for each class, a 2-D tensor of shape batch × n_classes
  • target: a 1-D tensor of length batch, containing class indices (0 to n_classes − 1).
    Note that the target of CrossEntropyLoss() is also a class index, not a one-hot encoding.

Example 2:

import torch
import torch.nn as nn
from torch import autograd

loss = nn.CrossEntropyLoss()
# input is of size nBatch x nClasses = 3 x 5
input = autograd.Variable(torch.randn(3, 5), requires_grad=True)
# each element in target has to have 0 <= value < nClasses
target = autograd.Variable(torch.LongTensor([1, 0, 4]))
output = loss(input, target)

Examples 1 and 2 produce the same result.
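A quick way to verify this equivalence numerically (a minimal sketch; the full version with printed values appears later in this post):

import torch
import torch.nn as nn

input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])

loss1 = nn.NLLLoss()(nn.LogSoftmax(dim=1)(input), target)  # Example 1
loss2 = nn.CrossEntropyLoss()(input, target)               # Example 2
print(torch.allclose(loss1, loss2))                        # True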

How do we compute the loss if the labels are one-hot encoded?


for images, target in train_loader:
    images, target = images.cuda(), target.cuda()
    N = target.size(0)   # N is the batch size
    # C is the number of classes (must be set beforehand)
    labels = torch.full(size=(N, C), fill_value=0.0).cuda()
    # scatter 1s into the target positions to build the one-hot matrix
    labels.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1)

    score = model(images)
    log_prob = torch.nn.functional.log_softmax(score, dim=1)
    # cross entropy with one-hot labels: -sum(one_hot * log_prob), averaged over the batch
    loss = -torch.sum(log_prob * labels) / N
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Here N is the batch size, C is the number of classes, and labels is the one-hot encoded 2-D tensor.
The targets from Examples 1 and 2 must first be converted into one-hot labels in this way.
This loss computation is equivalent to, and can replace, the loss computations of Examples 1 and 2.
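As a side note (not in the original post), if your PyTorch version provides torch.nn.functional.one_hot, the scatter_ step can be replaced by a single call:

import torch
import torch.nn.functional as F

target = torch.tensor([1, 0, 4])
# F.one_hot builds the same one-hot matrix as the scatter_ trick above;
# it returns int64, so convert to float before multiplying with log-probabilities
labels = F.one_hot(target, num_classes=5).float()
print(labels)
# tensor([[0., 1., 0., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1.]])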

A complete runnable version of the computations above:

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F

# LogSoftmax + NLLLoss
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
input = autograd.Variable(torch.randn(3, 5), requires_grad=True)
target = autograd.Variable(torch.LongTensor([1, 0, 4]))
output = loss(m(input), target)
print('logsoftmax + nllloss output is {}'.format(output))

# CrossEntropyLoss
loss = nn.CrossEntropyLoss()
# input = autograd.Variable(torch.randn(3, 5), requires_grad=True)  # reuse the same input so the results can be compared
target = autograd.Variable(torch.LongTensor([1, 0, 4]))
output = loss(input, target)
print('crossentropy output is {}'.format(output))


# one-hot label loss
C = 5
target = autograd.Variable(torch.LongTensor([1, 0, 4]))
print('target is {}'.format(target))
N = target.size(0)   # N is the batch size
# C is the number of classes
labels = torch.full(size=(N, C), fill_value=0.0)
print('labels shape is {}'.format(labels.shape))
labels.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1)
print('labels is {}'.format(labels))

log_prob = torch.nn.functional.log_softmax(input, dim=1)
loss = -torch.sum(log_prob * labels) / N
print('N is {}'.format(N))
print('one-hot loss is {}'.format(loss))

The output is:

logsoftmax + nllloss output is 3.005390167236328
crossentropy output is 3.005390167236328
target is tensor([1, 0, 4])
labels shape is torch.Size([3, 5])
labels is tensor([[0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])
N is 3
one-hot loss is 3.005390167236328

As you can see, given the same input, all three computations are equivalent.

Supplement:

The cross-entropy-related classes in torch.nn and their counterparts in torch.nn.functional correspond as follows:

  • nn.Softmax ↔ F.softmax
  • nn.LogSoftmax ↔ F.log_softmax
  • nn.NLLLoss ↔ F.nll_loss
  • nn.CrossEntropyLoss ↔ F.cross_entropy

The difference between torch.nn and torch.nn.functional is that each class in torch.nn is essentially a wrapper around the corresponding function in F.
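For example (a minimal sketch), nn.CrossEntropyLoss and F.cross_entropy give the same value on the same input:

import torch
import torch.nn.functional as F

input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])

loss_f = F.cross_entropy(input, target)               # functional form: a plain function call
loss_m = torch.nn.CrossEntropyLoss()(input, target)   # class form: construct the module, then call it
print(torch.allclose(loss_f, loss_m))                 # True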

