torch.max的理解
關於這個API,他的名字是torch.max,根據名字不難意識到,它表示尋找最大值
在tf2.0中,對應API的名字是reduce_max,對於這個名字我一開始是無法理解,reduce的含義,但是根據一些代碼,確實有這樣的認識:我們要想找出一些最大值,是要拋棄一些非最大值,這就是一個緯度的縮減
chat is cheap,我們來看一下代碼
import torch
a=torch.stack([torch.randperm(10),torch.randperm(10),torch.randperm(10)])
print(a)
print(a.max(0)[0])
print(a.max(1)[0])
輸出結果:
分析:max(0),表示縮減第一個緯度,也就是在第一個緯度上尋找最大值,所以原本我們的張量大小是[3,10] ,縮減了第一個緯度之後就是[10],那麼max(1)也是同理的,縮減了第二個緯度之後就是[3]