torch 的 F.cross_entropy

torch中的交叉熵損失函數使用案例

import torch
import torch.nn.functional as F

pred = torch.randn(3, 5)
print(pred.shape)

target = torch.tensor([2, 3, 4]).long() # 需要是整數
print(target.shape)

# 交叉熵損失函數, 輸入的參數是形狀不一樣的
# predict會在其內部進行softmax操作
loss = F.cross_entropy(pred, target)
loss.item()

結果爲:

需要注意的是, 傳入的參數形狀是不同的, predict是softmax之前的, 另外y需要是整形的, int也行

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章