torch nll_loss

 

正確格式:

1.

data:3,3,3,2

label:3,3,2

2.data:3:2

label:2

3.data:2,1,6

label:2,6

 

Traceback (most recent call last):
torch.Size([3, 3, 3, 2]) torch.Size([3, 2])
  File "F:/project/chedaoxian/Ultra-Fast-bar-Detection/utils/loss.py", line 91, in <module>
    out_size, target.size()))
ValueError: Expected target size (3, 3, 2), got torch.Size([3, 2])



 

data維度 [2,6]

label維度[2],這樣纔可以,增加維度就報錯。

如果data是[2,1,6]

label就需要是[2,6]

  import torch
    import torch.nn as nn
    import torch.nn.functional as F

    data = torch.randn(2,6)

    target = torch.tensor([1,2])

    print('data:', data.size(),target.size())
    entropy_out = F.cross_entropy(data, target)
    print('entropy_out:', entropy_out, target.size())
    log_soft = F.log_softmax(data, dim=1)
    print('log_soft:', log_soft, '\n')
    nll_out = F.nll_loss(log_soft, target)


    print('nll_out:', nll_out)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章