Pytorch基礎函數(三)基本數學函數(1)——max()與min()函數(對dim參數的詳解)

這篇博客介紹兩個基本的數學函數,max()與min()兩個。從函數名就可以看出這兩個函數的作用:

max():查找最大值

min():查找最小值

筆者一開始遇到這兩個函數,感覺這兩個簡單函數對Tensor(張量)的操作還是有一些彎彎繞。我們來梳理一下,梳理出來就會清楚很多。

一、max()函數

函數定義:torch.max(input, dim, max=None, max_indices=None,keepdim=False)

參數:input:進行max操作的Tensor變量

           dim:需要查找最大值得維度(這裏很迷,後面重點介紹)

           max:結果張量,用於存儲查找到的最大值

           max_indices:結果張量,用於存儲查找到最大值所處的索引

           keepdim=False:返回值與原Tensor的size保持一致

1. 簡單應用

以一個一維Tensor爲例,解釋max()函數的輸入、輸出。

t1=torch.LongTensor([3,9,6,2,5])
print("-------max-------")
print(torch.max(t1))
print("-------max dim-------")
print(torch.max(t1,dim=0))

輸出結果爲:

可以看到,加了dim參數後,返回值中多了一個indices Tensor,這個張量用於存儲下最大值的下標,例子中最大值9的下標爲1。

 

2. 二維Tensor

對二維Tensor使用max/min函數,必須搞清楚的就是dim參數,先說結論:

①. dim爲0,用於查找每的最大值。返回下標索引。

②. dim爲1,用於查找每的最大值。返回下標索引。

③. 不添加dim參數,返回所有值中的最大值,且無索引。這裏放在4.中展示。

從這裏看就有些奇怪了,因爲衆所周知,二維情況下,第0維爲行,第1維爲列。爲什麼dim爲0時返回每列的最大值。

先看一個例子,以一個兩行三列的Tensor(size=2x3)維例:

t=torch.randn(2,3)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))

輸出結果爲: 

當dim=0時,輸出最大值爲,第一列最大值0.6301,第二列最大值0.8937,第三列最大值0.3851。

當dim=1時,輸出最大值爲,第一行最大值0.8937,第二行最大值0.6301。

 

我們結論是正確的,我們從下標來分析一下:

首先,當dim=0時,三個列最大值的下標,分別爲[1][0]、[0][1]、[1][2]。(以及返回的索引張量[1,0,1])

           當dim=1時,兩個行最大值的下標,分別爲[0][1]、[1][0]。(以及返回的索引張量[1,0])

我們能夠看到,max()得到的最大值,本質上,是除了dim維以外,取其餘維度逐一遍歷分組(紅色下標),組內補上每一個dim維後的幾個數據的內部比較。

對dim參數的結論

           在其他維度均確定的情況下,比較所有dim維對應的數據,找到其中的最大值,並返回索引。

我們根據此例進行分析:

當dim=0時,除了dim等於的第0維,還有第1維,遍歷第1維,得到[0],[1],[2]。再補上第0維,根據遍歷第1維得到的,三個1維下標分爲三組。第一組([0][0],[1][0])、第二組([0][1],[1][1])、第三組([0][2],[1][2])。進行內部比較,得到三個組內最大值,即[0.6301,0.8937,0.3851],得到索引[1,0,1]。所以,也就是每一列的最大值了。

同理可以分析該例子中,dim=1的情況。

但是對於二維Tensor來說,記住結論比理解這個更容易。當三維及以上時,理解 這個就變得很重要了。

3、二維以上Tensor使用

這裏主要使用病分析一個,三維的Tensor使用max操作來驗證我們上面的結論。

例子:

t=torch.randn(2,2,2)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))

輸出結果:

分析:

      ①. 對於dim=0,遍歷除了第0維外,得到的第1維、第2維組合有[0][0],[0][1],[1][0],[1][1]。所以分爲的組有第一組([0][0][0],[1][0][0]),第二組([0][0][1],[1][0][1]),第三組([0][1][0],[1][1][0]),第四組([0][1][1],[1][1][1])。

      數據就是(0.9560,0.0632),(1.6869,0.3790),(1.1282,0.8084),(0.8298,-1.4528)

      max結果得到:最大值[0.9560,1.6869,1.1282,0.8298],索引[0,0,0,0]

      ②. 對於dim-1,遍歷除了第1維外,得到的第0維、第2維組合有[0]_[0],[0]_[1],[1]_[0],[1]_[1]。將第1維加入後,既可以分爲([0][0][0],[0][1][0]),([0][0][1],[0][1][1]),([1][0][0],[1][1][0]),([1][0][1],[1][1][1]).

      得到的數據組爲(0.9560,1.1282),(1.6869,0.8298),(0.0632,0.8084),(0.3790,-1.4528)

      max結果得到: 最大值[1.1282,1.6869,0.8084,0.3790],索引[1,0,1,0]

4. 無dim參數的max()函數

當使用torch.max()函數時,不添加dim函數,則返回所有元素中值最大值(格式爲size爲1Tensor),且無索引。

例子:

t=torch.randn(2,2,2)
print(t)
print(torch.max(t))

輸出結果:

結果輸出,所有元素中的最大值。

二、min()函數

與max相同,但是返回爲最小值。

-------end------

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章