pytorch杂记-torch.max

torch.max的理解

关于这个API,他的名字是torch.max,根据名字不难意识到,它表示寻找最大值

在tf2.0中,对应API的名字是reduce_max,对于这个名字我一开始是无法理解,reduce的含义,但是根据一些代码,确实有这样的认识:我们要想找出一些最大值,是要抛弃一些非最大值,这就是一个纬度的缩减

chat is cheap,我们来看一下代码

import torch

a=torch.stack([torch.randperm(10),torch.randperm(10),torch.randperm(10)])

print(a)

print(a.max(0)[0])

print(a.max(1)[0])

输出结果:
在这里插入图片描述

分析:max(0),表示缩减第一个纬度,也就是在第一个纬度上寻找最大值,所以原本我们的张量大小是[3,10] ,缩减了第一个纬度之后就是[10],那么max(1)也是同理的,缩减了第二个纬度之后就是[3]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章