詳解pytorch中的max方法

實際上pytorch官方文檔中的應該是torch.max(input)方法,而本文要講的可能嚴格意義上不是torch中的,而是針對torch中的張量方法,即input.max(axis)[index]
其中input表示要求取最大值的張量,axis可以爲0(表示求取每列的最大值),也可以爲1(每行的最大值)。index爲0表示只返回最大值本身,爲1表示只返回最大值對應的索引。如下,其中axis可以省去:

a = torch.Tensor([[0,3,2],[4,0,0]])
print(a.max(axis=0)[0]) # tensor([4., 3., 2.]),即第一列爲[0 4]最大值爲4,第二列爲[3 0],依此類推
print(a.max(axis=0)[1]) # tensor([1, 0, 0]),索引也是列的索引
print(a.max(axis=1)[0]) # tensor([3., 4.]),取各行的最大值
print(a.max(axis=1)[1]) # tensor([1, 0]),對應的索引

應用

在求解強化學習中需要qmaxq_{max}對應的action時,通常是輸入一個張量即神經網絡算出的q值,然後輸出q值對應的索引,輸出的是int型,如下:

import torch
q = torch.Tensor([[0,3,2,1]])
action=q.max(1)[1].item() # .item()將只有一個元素的張量變爲對應的元素
action=q.max(1)[1].view(1,1).item() # 如果不放心可在前面加view方法shape成只有一個元素的張量
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章