這篇博客介紹兩個基本的數學函數,max()與min()兩個。從函數名就可以看出這兩個函數的作用:
max():查找最大值
min():查找最小值
筆者一開始遇到這兩個函數,感覺這兩個簡單函數對Tensor(張量)的操作還是有一些彎彎繞。我們來梳理一下,梳理出來就會清楚很多。
一、max()函數
函數定義:torch.max(input, dim, max=None, max_indices=None,keepdim=False)
參數:input:進行max操作的Tensor變量
dim:需要查找最大值得維度(這裏很迷,後面重點介紹)
max:結果張量,用於存儲查找到的最大值
max_indices:結果張量,用於存儲查找到最大值所處的索引
keepdim=False:返回值與原Tensor的size保持一致
1. 簡單應用
以一個一維Tensor爲例,解釋max()函數的輸入、輸出。
t1=torch.LongTensor([3,9,6,2,5])
print("-------max-------")
print(torch.max(t1))
print("-------max dim-------")
print(torch.max(t1,dim=0))
輸出結果爲:
可以看到,加了dim參數後,返回值中多了一個indices Tensor,這個張量用於存儲下最大值的下標,例子中最大值9的下標爲1。
2. 二維Tensor
對二維Tensor使用max/min函數,必須搞清楚的就是dim參數,先說結論:
①. dim爲0,用於查找每列的最大值。返回行下標索引。
②. dim爲1,用於查找每行的最大值。返回列下標索引。
③. 不添加dim參數,返回所有值中的最大值,且無索引。這裏放在4.中展示。
從這裏看就有些奇怪了,因爲衆所周知,二維情況下,第0維爲行,第1維爲列。爲什麼dim爲0時返回每列的最大值。
先看一個例子,以一個兩行三列的Tensor(size=2x3)維例:
t=torch.randn(2,3)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))
輸出結果爲:
當dim=0時,輸出最大值爲,第一列最大值0.6301,第二列最大值0.8937,第三列最大值0.3851。
當dim=1時,輸出最大值爲,第一行最大值0.8937,第二行最大值0.6301。
我們結論是正確的,我們從下標來分析一下:
首先,當dim=0時,三個列最大值的下標,分別爲[1][0]、[0][1]、[1][2]。(以及返回的索引張量[1,0,1])
當dim=1時,兩個行最大值的下標,分別爲[0][1]、[1][0]。(以及返回的索引張量[1,0])
我們能夠看到,max()得到的最大值,本質上,是除了dim維以外,取其餘維度逐一遍歷分組(紅色下標),組內補上每一個dim維後的幾個數據的內部比較。
對dim參數的結論:
在其他維度均確定的情況下,比較所有dim維對應的數據,找到其中的最大值,並返回索引。
我們根據此例進行分析:
當dim=0時,除了dim等於的第0維,還有第1維,遍歷第1維,得到[0],[1],[2]。再補上第0維,根據遍歷第1維得到的,三個1維下標分爲三組。第一組([0][0],[1][0])、第二組([0][1],[1][1])、第三組([0][2],[1][2])。進行內部比較,得到三個組內最大值,即[0.6301,0.8937,0.3851],得到索引[1,0,1]。所以,也就是每一列的最大值了。
同理可以分析該例子中,dim=1的情況。
但是對於二維Tensor來說,記住結論比理解這個更容易。當三維及以上時,理解 這個就變得很重要了。
3、二維以上Tensor使用
這裏主要使用病分析一個,三維的Tensor使用max操作來驗證我們上面的結論。
例子:
t=torch.randn(2,2,2)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))
輸出結果:
分析:
①. 對於dim=0,遍歷除了第0維外,得到的第1維、第2維組合有[0][0],[0][1],[1][0],[1][1]。所以分爲的組有第一組([0][0][0],[1][0][0]),第二組([0][0][1],[1][0][1]),第三組([0][1][0],[1][1][0]),第四組([0][1][1],[1][1][1])。
數據就是(0.9560,0.0632),(1.6869,0.3790),(1.1282,0.8084),(0.8298,-1.4528)
max結果得到:最大值[0.9560,1.6869,1.1282,0.8298],索引[0,0,0,0]
②. 對於dim-1,遍歷除了第1維外,得到的第0維、第2維組合有[0]_[0],[0]_[1],[1]_[0],[1]_[1]。將第1維加入後,既可以分爲([0][0][0],[0][1][0]),([0][0][1],[0][1][1]),([1][0][0],[1][1][0]),([1][0][1],[1][1][1]).
得到的數據組爲(0.9560,1.1282),(1.6869,0.8298),(0.0632,0.8084),(0.3790,-1.4528)
max結果得到: 最大值[1.1282,1.6869,0.8084,0.3790],索引[1,0,1,0]
4. 無dim參數的max()函數
當使用torch.max()函數時,不添加dim函數,則返回所有元素中值最大值(格式爲size爲1Tensor),且無索引。
例子:
t=torch.randn(2,2,2)
print(t)
print(torch.max(t))
輸出結果:
結果輸出,所有元素中的最大值。
二、min()函數
與max相同,但是返回爲最小值。
-------end------