1. 數據類型不匹配:
報錯:Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #2 ‘target’
criterion = nn.CrossEntropyLoss()
loss = criterion(y_pre, y_train)
這裏的y_train類型一定要是LongTensor的,所以在寫DataSet的時候返回的label就要是LongTensor類型的,如下
def__init__(self, ...):
self.label = torch.LongTensor(label)
2.target要用類標
報錯:multi-target not supported at c:\new-builder_2\win-wheel\pytorch\aten\src\thnn\generic/ClassNLLCriterion.c:21
criterion = nn.CrossEntropyLoss()
loss = criterion(y_pre, y_train)
這裏的y_train不能用one-hot編碼,要用類標值。