pytorch学习——torch.max

torch.max学习记录:

首先定义数据:

import torch
a = torch.randn(2,3)
print("a:",a)

结果如下:

a: tensor([[-0.5658, -0.9736, -1.1753],
        [ 1.2006,  0.4078, -2.0542]])

 

torch.max(inputdim

按维度dim 返回最大值

torch.max(a,dim=0) 返回值为一个元组,元组里包含两个,第一个值为一个每一列中最大元素,第二个值为最大元素在这一列的行索引

返回的又是列又是行的,这句话怎么理解呢?

可以理解成:dim=0,第0个维度表示行,可以想象是你的手,从上往下挤压(对应dim=0,第一行,第二行...从上往下的一个行方向),直到压扁的一个过程(这个要理解清楚)。在此过程中,只保存每一列的最大值,同时记录下这个最大值是第几行(即行索引)。

       [ [-0.5658, -0.9736, -1.1753],
        [ 1.2006,  0.4078, -2.0542 ] ]

以这个数据来说,Size[2,3] ,使用torch.max(a,dim=0), 

结果为:

(tensor([ 1.2006,  0.4078, -1.1753]), tensor([1, 1, 0]))

想象有只手从上往下压,把它压扁,保留每一列最大值[ 1.2006, 0.4078, -1.1753] , 同时保留行索引。1.2006在第一行,0.4078在第一行,-1.1753在第0行,对应后面的 tensor([1, 1, 0]。

也可以通过索引,取出元组中的结果:

print("a.max(dim=0)[0]:",a.max(dim=0)[0])
print("a.max(dim=0)[1]:",a.max(dim=0)[1])

输出:

a.max(dim=0)[0]: tensor([ 1.2006,  0.4078, -1.1753])
a.max(dim=0)[1]: tensor([1, 1, 0])

 

torch.max(a,diim=1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引

同理,可以想象为是有只手从左往右挤压(对应dim=1,第一列,第二列...从左往右的一个列方向),直到压扁的一个过程

贴上代码:

print("a.max(dim=1)[0]:",a.max(dim=1)[0])
print("a.max(dim=1)[1]:",a.max(dim=1)[1])

输出结果:

a.max(dim=1)[0]: tensor([-0.5658,  1.2006])
a.max(dim=1)[1]: tensor([0, 0])

torch.max与numpy.max的用法类似。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章