torch.max学习记录:
首先定义数据:
import torch
a = torch.randn(2,3)
print("a:",a)
结果如下:
a: tensor([[-0.5658, -0.9736, -1.1753],
[ 1.2006, 0.4078, -2.0542]])
torch.
max
(input, dim)
按维度dim 返回最大值
torch.max(a,dim=0) 返回值为一个元组,元组里包含两个值,第一个值为一个每一列中最大元素,第二个值为最大元素在这一列的行索引
返回的又是列又是行的,这句话怎么理解呢?
可以理解成:dim=0,第0个维度表示行,可以想象是你的手,从上往下挤压(对应dim=0,第一行,第二行...从上往下的一个行方向),直到压扁的一个过程(这个要理解清楚)。在此过程中,只保存每一列的最大值,同时记录下这个最大值是第几行(即行索引)。
[ [-0.5658, -0.9736, -1.1753],
[ 1.2006, 0.4078, -2.0542 ] ]
以这个数据来说,Size[2,3] ,使用torch.max(a,dim=0),
结果为:
(tensor([ 1.2006, 0.4078, -1.1753]), tensor([1, 1, 0]))
想象有只手从上往下压,把它压扁,保留每一列最大值[ 1.2006, 0.4078, -1.1753] , 同时保留行索引。1.2006在第一行,0.4078在第一行,-1.1753在第0行,对应后面的 tensor([1, 1, 0]。
也可以通过索引,取出元组中的结果:
print("a.max(dim=0)[0]:",a.max(dim=0)[0])
print("a.max(dim=0)[1]:",a.max(dim=0)[1])
输出:
a.max(dim=0)[0]: tensor([ 1.2006, 0.4078, -1.1753])
a.max(dim=0)[1]: tensor([1, 1, 0])
torch.max(a,diim=1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
同理,可以想象为是有只手从左往右挤压(对应dim=1,第一列,第二列...从左往右的一个列方向),直到压扁的一个过程
贴上代码:
print("a.max(dim=1)[0]:",a.max(dim=1)[0])
print("a.max(dim=1)[1]:",a.max(dim=1)[1])
输出结果:
a.max(dim=1)[0]: tensor([-0.5658, 1.2006])
a.max(dim=1)[1]: tensor([0, 0])
torch.max与numpy.max的用法类似。