torch max()函數

torch.max()返回的是兩個值, 第一個是最大值, 第二個是最大值所在的索引, 一般情況,我們都是求最大值所在的索引

import torch

a = torch.tensor([[1, 5, 2, 1], [2, 6, 3, 8]])
print(a)

res, index = torch.max(a, 1)
print(res)
print(index)

只用最大值索引求準確率:

# 準確率的計算  
# 100個樣本, 10 個類別
predict = torch.rand(100, 10)
label = torch.randint(10, (100,), dtype=torch.int64)

pred_y = torch.max(predict, 1)[1].numpy()
y_label = label.numpy()

accuracy = (pred_y == y_label).sum() / len(y_label)
print("準確率:", accuracy)

結果爲

準確率: 0.21

這裏是取的隨機數, 結果不重要

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章