pytorch criterion踩坑小結

1. 數據類型不匹配:

報錯:Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #2 ‘target’

criterion = nn.CrossEntropyLoss()
loss = criterion(y_pre, y_train)

這裏的y_train類型一定要是LongTensor的,所以在寫DataSet的時候返回的label就要是LongTensor類型的,如下

def__init__(self, ...):
	self.label = torch.LongTensor(label)

2.target要用類標

報錯:multi-target not supported at c:\new-builder_2\win-wheel\pytorch\aten\src\thnn\generic/ClassNLLCriterion.c:21

criterion = nn.CrossEntropyLoss()
loss = criterion(y_pre, y_train)

這裏的y_train不能用one-hot編碼,要用類標值。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章