torch.max总结

torch.max被广泛应用在评价model的最终预测性能中,其实这个问题大家已经总结得挺详细了,例如:

https://blog.csdn.net/liuweiyuxiang/article/details/84668269

https://www.cnblogs.com/Archer-Fang/p/10651029.html

但是正如前面一个博文里网友评论的那样,似乎拿行、列来区分不太妥当。当然我也没有想到更好的办法来总结,根据例子就很快可以掌握了:

1. torch.max(a)是返回a中的最大值:

a=torch.tensor([[-2.1456, -0.6380,  1.3625],
                [-1.0394, -0.9641, -0.3667]])

print(torch.max(a))将得到:

tensor(1.3625) 

(另外,怎么把这个转成int呢?加上.item()即可)

另外torch.max(a)==a.max()

2. torch.max(a,1)返回的是每一行的最大值,还有最大值所在的索引:

torch.return_types.max(
values=tensor([ 1.3625, -0.3667]),
indices=tensor([2, 2]))

当然我们很多情况下只关心index(例如计算accuracy的时候),那么这时候用

torch.max(a,1)[1] 或者 a.max(1)[1] 取出来即可:

tensor([2, 2])

再加上.numpy()就可以转成array:

print(a.max(1)[1].numpy())得到:

[2 2]

基本的使用方法就是这些,但是有一个问题,为什么

torch.max(a,1)是每一行的最大值而torch.max(a,0)是每一列的最大值呢?

例如上面这个例子,print(torch.max(a,0))的输出是:

torch.return_types.max(
values=tensor([-1.0394, -0.6380,  1.3625]),
indices=tensor([1, 0, 0]))

实际上可以这样理解:0指的是在dimension 0中,各个vector之间比较,取到vector每一维的最大值。1指的是dimension 1中,每个逗号之间的元素进行比较,例如在[-2.1456, -0.6380,  1.3625]几个数中间进行比较,得到的最大值就是一个标量了,然后这些最大值拼接成一个vector。

所以可以思考一下这种情况(虽然遇到的很少,所以其实我们可以按照0,1分别对应行和列来理解):

a=torch.tensor([[[-0.2389, -0.8487, -1.5907,  0.0732],
                 [-0.2159,  1.1064, -1.1317,  0.6457],
                 [ 0.8191,  1.0146,  1.0241,  0.7042]],
                 [[-0.8285,  0.3628,  1.4678,  0.7984],
                  [ 0.1009, -0.3307, -0.8245,  0.0044],
                  [-1.5041,  0.5067,  0.4085,  0.2126]]])

在这个时候,print(torch.max(a,0))得到的结果是:

values=tensor([[-0.2389,  0.3628,  1.4678,  0.7984],
        [ 0.1009,  1.1064, -0.8245,  0.6457],
        [ 0.8191,  1.0146,  1.0241,  0.7042]]),
indices=tensor([[0, 1, 1, 1],
        [1, 0, 1, 0],
        [0, 0, 0, 0]]))

大家可以看看是不是这么回事。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章