torch 的 F.cross_entropy

torch中的交叉熵损失函数使用案例

import torch
import torch.nn.functional as F

pred = torch.randn(3, 5)
print(pred.shape)

target = torch.tensor([2, 3, 4]).long() # 需要是整数
print(target.shape)

# 交叉熵损失函数, 输入的参数是形状不一样的
# predict会在其内部进行softmax操作
loss = F.cross_entropy(pred, target)
loss.item()

结果为:

需要注意的是, 传入的参数形状是不同的, predict是softmax之前的, 另外y需要是整形的, int也行

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章