torch中的交叉熵損失函數使用案例
import torch
import torch.nn.functional as F
pred = torch.randn(3, 5)
print(pred.shape)
target = torch.tensor([2, 3, 4]).long() # 需要是整數
print(target.shape)
# 交叉熵損失函數, 輸入的參數是形狀不一樣的
# predict會在其內部進行softmax操作
loss = F.cross_entropy(pred, target)
loss.item()
結果爲:
需要注意的是, 傳入的參數形狀是不同的, predict是softmax之前的, 另外y需要是整形的, int也行