Correct shapes for F.cross_entropy (data = input logits, label = class-index target):
1. data: 3,3,3,2
   label: 3,3,2
2. data: 3,2
   label: 3
3. data: 2,1,6
   label: 2,6
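All three cases follow one rule: the label shape is the data shape with dimension 1 (the class dimension C) removed. Below is a minimal pure-Python sketch of that rule. It assumes class-index labels (not class probabilities) and a batched input laid out as (N, C, d1, d2, ...). `expected_target_shape` is a hypothetical helper and not a PyTorch function.

```python
def expected_target_shape(input_shape):
    """Hypothetical helper: class-index target shape for a batched (N, C, ...) input."""
    input_shape = tuple(input_shape)
    if len(input_shape) < 2:
        raise ValueError("this sketch only handles batched input (N, C, ...)")
    # Drop the class dimension C (dim 1); keep the batch dim N and any extra dims.
    return input_shape[:1] + input_shape[2:]


for shape in [(3, 3, 3, 2), (3, 2), (2, 1, 6), (2, 6)]:
    print(shape, "->", expected_target_shape(shape))
```

This prints `(3, 3, 2)`, `(3,)`, `(2, 6)` and `(2,)` for the four inputs, matching the cases above.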
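Example of the error when the label shape is wrong (data 3,3,3,2 with label 3,2). The first line is the script's own print of the two shapes:
torch.Size([3, 3, 3, 2]) torch.Size([3, 2])
Traceback (most recent call last):
  File "F:/project/chedaoxian/Ultra-Fast-bar-Detection/utils/loss.py", line 91, in <module>
    out_size, target.size()))
ValueError: Expected target size (3, 3, 2), got torch.Size([3, 2])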
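To get a readable error before the loss is called, the shapes can be checked up front. The sketch below uses the same assumptions: class-index labels and an input laid out as (N, C, d1, ...). `check_cross_entropy_shapes` is a hypothetical helper and not part of PyTorch.

```python
def check_cross_entropy_shapes(input_shape, target_shape):
    """Hypothetical pre-check: fail early with a readable message."""
    input_shape, target_shape = tuple(input_shape), tuple(target_shape)
    if len(input_shape) < 2:
        raise ValueError("this sketch only handles batched input (N, C, ...)")
    # The target must equal the input shape with the class dimension (dim 1) removed.
    expected = input_shape[:1] + input_shape[2:]
    if target_shape != expected:
        raise ValueError(
            f"input {input_shape} needs target {expected}, got {target_shape}"
        )


# The failing case from the traceback: data (3, 3, 3, 2) with label (3, 2).
try:
    check_cross_entropy_shapes((3, 3, 3, 2), (3, 2))
except ValueError as e:
    print(e)  # input (3, 3, 3, 2) needs target (3, 3, 2), got (3, 2)
```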
With data of shape [2,6], the label must have shape [2]. Adding a dimension to the label alone raises an error.
If data is [2,1,6], the label must be [2,6].
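import torch
import torch.nn.functional as F
# data: (N=2, C=6) logits; target: (N=2,) class indices, each in [0, 6)
data = torch.randn(2, 6)
target = torch.tensor([1, 2])
print('data:', data.size(), target.size())
# cross_entropy applies log_softmax over the class dim (dim=1), then nll_loss
entropy_out = F.cross_entropy(data, target)
print('entropy_out:', entropy_out, target.size())
# The same loss computed in two explicit steps
log_soft = F.log_softmax(data, dim=1)
print('log_soft:', log_soft, '\n')
nll_out = F.nll_loss(log_soft, target)
print('nll_out:', nll_out)
# Both results should match
print('equal:', torch.allclose(entropy_out, nll_out))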
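The script above checks that `F.cross_entropy` gives the same value as `F.nll_loss` applied to `F.log_softmax(data, dim=1)`. Here is a plain-Python sketch of that computation for the 2-D (N, C) case. It assumes the default `reduction='mean'` with no class weights and no `ignore_index`. The function names mirror PyTorch's only for readability; they are not the library's implementation.

```python
import math


def log_softmax(row):
    # Subtract the max before exponentiating, for numerical stability.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum for v in row]


def nll_loss(log_probs, target):
    # Mean over samples of the negative log-probability at each sample's class index.
    return -sum(lp[t] for lp, t in zip(log_probs, target)) / len(target)


def cross_entropy(data, target):
    # cross_entropy(x, y) == nll_loss(log_softmax(x, dim=1), y)
    return nll_loss([log_softmax(row) for row in data], target)


data = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        [1.0, 0.0, -1.0, 2.0, 0.5, 0.0]]
target = [1, 2]
print(cross_entropy(data, target))
```

As a sanity check, all-zero logits over 6 classes give a loss of log(6) regardless of the labels, because every class then gets probability 1/6.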