本文整理了筆者在學習pytorch中經常遇到的一些函數,本篇博客會不斷進行更新,並且會加上自己使用背景和使用經驗。
1. torch.max()函數
筆者最近在學習目標檢測的相關知識,無論是在計算多個bounding box之間的IOU還是確定bounding box的類別信息的時候,都會用到torch.max()函數。torch.max()可以得到一個tensor某個維度的最大值,可以的得到兩個tensor之間對應元素之間的最大值。
這個函數的簽名是這樣的:
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
1.1 單個tensor
情況1:如果輸入tensor是一維:
# 輸入一維的tensor
a = torch.tensor([9,4,6,3,1,2,5,9])
print(torch.max(a))
輸出結果就是這一維數據中,最大的元素
tensor(9)
情況2:如果輸入tensor是二維:
這時候就涉及到要比較哪個維度上的數據,是按行比較呢?還是按列比較呢?還是就是求這個tensor中最大的哪個元素?
(1)按列比較
# 求每一列最大的元素
b = torch.tensor([[4,8,2],[2,9,2]])
max_value, max_index = torch.max(b,dim=0)
print(max_value)
print(max_index)
max_value表示得到的最終數據,max_index表示得到最大元素的每一列中索引。最終的數據結果是這樣的:
# max_value: 表示每一列的最大值
tensor([4, 9, 2])
# max_index: 表示最大值在每一列中的索引
tensor([0, 1, 1])
我們發現,最終的結果相對於輸入來說,結果的維度比輸入的少了一個維度,這是因爲我們指定了要比較哪個維度上面的數據。如果通過一幅圖片顯示torch.max()函數的效果就是這樣:
(2)按行比較
# 求每一列最大的元素
b = torch.tensor([[4,8,2],[2,9,2]])
max_value, max_index = torch.max(b,dim=1)
print(max_value)
print(max_index)
最終的結果就是這樣的
# max_value: 表示每一行的最大值
tensor([8, 9])
# max_index: 表示最大值在每一行中的索引
tensor([1, 1])
(3)求整個tensor中最大的元素
# 求一個tensor中最大的元素
b = torch.tensor([[4,8,2],[2,9,2]])
# 輸出結果
print(torch.max(b))
# tensor([9])
1.2 兩個tensor
torch.max()還可以比較兩個tensor之間的最大值
情況1:單元素和另外一個tensor進行比較
c = torch.tensor([5])
d = torch.tensor([1,2,3])
print(torch.max(c,d))
輸出的結果是:
tensor([5, 5, 5])
情況2:兩個不同tensor之間進行比較
這種情況下,必須要求兩個tensor的尺寸是一樣的,不然會報錯
c = torch.tensor([[1,2,3],[2,3,4]])
d = torch.tensor([[4,3,1],[3,6,3]])
print(torch.max(c,d))
輸出結果是這樣:
# 輸出結果
tensor([[4, 3, 3],
[3, 6, 4]])
2. torch.argmax()函數
torch.argmax()函數和torch.max()函數功能差不多,只不多前者只會返回一個tensor最大元素的索引,並不會返回這個最大元素是什麼。torch.max()函數不僅會返回最大值的索引,而且還會返回這個最大的元素是什麼。
max()函數的函數簽名是這樣的:
torch.argmax(input, dim=None, keepdim=False) -> Tensor
其中dim的含義是這樣的:the dimension to reduce(參考了這篇博客)
情況1:返回一個tensor中最大元素的索引
# 返回一個tensor最大值的索引
a = torch.tensor([[1,2,3],[2,3,4]])
print(torch.argmax(b))
# 輸出值
# tensor(5)
情況2:返回指定維度的最大元素索引
# 將第0維的數據進行比較
b = torch.tensor([[1,2,3],[2,3,4]])
print(torch.argmax(b,dim=0)
# 輸入結果
# tensor([1, 1, 1])
# 將第1維數據進行比較
print(torch.argmax(b,dim=1))
# 輸出結果
# tensor([2,2])
3. torch.nonzero()函數
torch.nonzero()函數有很多功能,比如能夠配合mask篩選出你想要的數據。比如在目標檢測求confidence小於threshold的bbox,或者只選擇某一類別的bbox,都可以使用nonzero()函數進行篩選。
torch.nonzero()函數簽名如下:
torch.nonzero(input, out=None) -> Tensor
這個函數的主要就是找到tensor中所有不爲零元素的索引。函數的返回值是一個z*n的tensor。z表示不爲零的元素的個數,n表示輸入tensor的維度。
# 打印a中所有不爲0的元素的索引
a = torch.tensor([[1,2,3],[1,2,0]])
index = torch.nonzero(a)
print(index)
輸出結果是這樣的:
tensor([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[1, 1]])
但是,我目前主要用的就是確定某一列中不含0的索引,然後再將其對應的某些行全部過濾出來。
# 第2列中有一個非零數據,我們想把非零列所對應的行過濾出來
a = torch.tensor([[1,2,3],[1,2,0]])
index = torch.nonzero(a[:,2])
print(a[index.squeeze()])
# 輸出效果
tensor([1, 2, 3])