pytorch : CrossEntropyLoss 應用於語義分割

原文鏈接:https://www.jianshu.com/p/a6131515ee1d

pytorch : CrossEntropyLoss 應用於語義分割

  • 作 者: 月牙眼的樓下小黑
  • 聯 系: zhanglf_tmac (Wechat)
  • 聲 明: 歡迎轉載本文中的圖片或文字,請說明出處

官方文檔中對 CrossEntropyLoss()的介紹:
在這裏插入圖片描述
其實: pytorch中的CrossEntropyLoss 是可以直接應用於語義分割任務的。

我們不妨假設一個分割網絡的輸出形狀爲: (channel = 3, width = 2, height = 2) ,即 2 x 2 分辨率的圖像,其中每個像素可能屬於 {0,1,2} 三類中的其中一類。

import torch
from torch import nn
from torch.autograd import Variable

input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)

loss = nn.CrossEntropyLoss()
print('loss: ', loss(input, target))

結果:

input: tensor([[[[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]]]], requires_grad=True)
target: tensor([[[0, 1],
         [1, 0]]])
loss:  tensor(1.0986, grad_fn=<NllLoss2DBackward>)

我們討論一下兩個細節:

問題1: 輸出的 loss 形狀爲什麼是 1x 1 ?
默認情況下,即 size_average = True, loss 會在每個 mini-batch(小批量) 上取平均值. 如果字段 size_average 被設置爲 False, loss 將會在每個 mini-batch(小批量) 上累加, 而不會取平均值.

那麼這個 mini_batch_size 等於幾呢? 在程序中,網絡輸出形狀爲 4-d Tensor: ( batch_size, channel, width, height)。 注意: mini_batch_size != batch_size, 而是: mini_batch_size = batch_size * width * height.

這非常好理解,因爲語義分割本質上是 pixel-level classification, 所以 mini_batch_size 就等於一個 batch 圖像中的 像素總數。
我們可以將上面代碼中 loss 參數 size_average 設爲 False , 做個簡單的驗證:

import torch
from torch import nn

input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)

loss = nn.CrossEntropyLoss(size_average=False)
print('loss', loss(input, target))

此時輸出的 loss 值爲: 4.3944, 正好是 1.0986 的 1 x 2 x 2 倍。

問題2:如何得到每個 pixel 的 loss ?
只需將loss 參數 reduce 設爲 False 即可。若網絡輸出形狀爲 4-d Tensor: ( batch_size, channel, width, height), 此時 loss 函數會返回一個 3-d Tensor: (batch_size, width, height), 每個元素對應一個 pixel 的 loss 值。

import torch
from torch import nn

input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))

loss = nn.CrossEntropyLoss(reduce=False)
print('loss: ', loss(input, target))

結果:

loss:  tensor([[[1.0986, 1.0986],
         [1.0986, 1.0986]]], grad_fn=<NllLoss2DBackward>)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章