pytorch : CrossEntropyLoss 應用於語義分割
- 作 者: 月牙眼的樓下小黑
- 聯 系: zhanglf_tmac (Wechat)
- 聲 明: 歡迎轉載本文中的圖片或文字,請說明出處
官方文檔中對 CrossEntropyLoss()的介紹:
其實: pytorch中的CrossEntropyLoss 是可以直接應用於語義分割任務的。
我們不妨假設一個分割網絡的輸出形狀爲: (channel = 3, width = 2, height = 2) ,即 2 x 2 分辨率的圖像,其中每個像素可能屬於 {0,1,2} 三類中的其中一類。
import torch
from torch import nn
from torch.autograd import Variable
input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)
loss = nn.CrossEntropyLoss()
print('loss: ', loss(input, target))
結果:
input: tensor([[[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]]], requires_grad=True)
target: tensor([[[0, 1],
[1, 0]]])
loss: tensor(1.0986, grad_fn=<NllLoss2DBackward>)
我們討論一下兩個細節:
問題1: 輸出的 loss 形狀爲什麼是 1x 1 ?
默認情況下,即 size_average = True, loss 會在每個 mini-batch(小批量) 上取平均值. 如果字段 size_average 被設置爲 False, loss 將會在每個 mini-batch(小批量) 上累加, 而不會取平均值.
那麼這個 mini_batch_size 等於幾呢? 在程序中,網絡輸出形狀爲 4-d Tensor: ( batch_size, channel, width, height)。 注意: mini_batch_size != batch_size, 而是: mini_batch_size = batch_size * width * height.
這非常好理解,因爲語義分割本質上是 pixel-level classification, 所以 mini_batch_size 就等於一個 batch 圖像中的 像素總數。
我們可以將上面代碼中 loss 參數 size_average 設爲 False , 做個簡單的驗證:
import torch
from torch import nn
input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)
loss = nn.CrossEntropyLoss(size_average=False)
print('loss', loss(input, target))
此時輸出的 loss 值爲: 4.3944, 正好是 1.0986 的 1 x 2 x 2 倍。
問題2:如何得到每個 pixel 的 loss ?
只需將loss 參數 reduce 設爲 False 即可。若網絡輸出形狀爲 4-d Tensor: ( batch_size, channel, width, height), 此時 loss 函數會返回一個 3-d Tensor: (batch_size, width, height), 每個元素對應一個 pixel 的 loss 值。
import torch
from torch import nn
input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
loss = nn.CrossEntropyLoss(reduce=False)
print('loss: ', loss(input, target))
結果:
loss: tensor([[[1.0986, 1.0986],
[1.0986, 1.0986]]], grad_fn=<NllLoss2DBackward>)