pytorch中常用的函數

本文整理了筆者在學習pytorch中經常遇到的一些函數,本篇博客會不斷進行更新,並且會加上自己使用背景和使用經驗。

1. torch.max()函數

筆者最近在學習目標檢測的相關知識,無論是在計算多個bounding box之間的IOU還是確定bounding box的類別信息的時候,都會用到torch.max()函數。torch.max()可以得到一個tensor某個維度的最大值,可以的得到兩個tensor之間對應元素之間的最大值。

這個函數的簽名是這樣的:

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

1.1 單個tensor

情況1:如果輸入tensor是一維

    # 輸入一維的tensor
    a = torch.tensor([9,4,6,3,1,2,5,9])
    print(torch.max(a))

輸出結果就是這一維數據中,最大的元素

    tensor(9)

情況2:如果輸入tensor是二維

這時候就涉及到要比較哪個維度上的數據,是按行比較呢?還是按列比較呢?還是就是求這個tensor中最大的哪個元素?

(1)按列比較

    # 求每一列最大的元素
    b = torch.tensor([[4,8,2],[2,9,2]])
    max_value, max_index = torch.max(b,dim=0)
    print(max_value)
    print(max_index)

max_value表示得到的最終數據,max_index表示得到最大元素的每一列中索引。最終的數據結果是這樣的:

    # max_value: 表示每一列的最大值
    tensor([4, 9, 2])

    # max_index: 表示最大值在每一列中的索引
    tensor([0, 1, 1])

我們發現,最終的結果相對於輸入來說,結果的維度比輸入的少了一個維度,這是因爲我們指定了要比較哪個維度上面的數據。如果通過一幅圖片顯示torch.max()函數的效果就是這樣:

(2)按行比較

    # 求每一列最大的元素
    b = torch.tensor([[4,8,2],[2,9,2]])
    max_value, max_index = torch.max(b,dim=1)
    print(max_value)
    print(max_index)

最終的結果就是這樣的

    # max_value: 表示每一行的最大值
    tensor([8, 9])

    # max_index: 表示最大值在每一行中的索引
    tensor([1, 1])

(3)求整個tensor中最大的元素

    # 求一個tensor中最大的元素
    b = torch.tensor([[4,8,2],[2,9,2]])
    
    # 輸出結果
    print(torch.max(b))
    
    # tensor([9])

 

1.2 兩個tensor

torch.max()還可以比較兩個tensor之間的最大值

情況1:單元素和另外一個tensor進行比較

    c = torch.tensor([5])
    d = torch.tensor([1,2,3])
    print(torch.max(c,d))

輸出的結果是:

    tensor([5, 5, 5])

 

情況2:兩個不同tensor之間進行比較

這種情況下,必須要求兩個tensor的尺寸是一樣的,不然會報錯

    c = torch.tensor([[1,2,3],[2,3,4]])
    d = torch.tensor([[4,3,1],[3,6,3]])
    print(torch.max(c,d))

輸出結果是這樣:

    # 輸出結果
    tensor([[4, 3, 3],
            [3, 6, 4]])

 

2. torch.argmax()函數

torch.argmax()函數和torch.max()函數功能差不多,只不多前者只會返回一個tensor最大元素的索引,並不會返回這個最大元素是什麼。torch.max()函數不僅會返回最大值的索引,而且還會返回這個最大的元素是什麼。

max()函數的函數簽名是這樣的:

    torch.argmax(input, dim=None, keepdim=False) -> Tensor

 其中dim的含義是這樣的:the dimension to reduce(參考了這篇博客)

情況1:返回一個tensor中最大元素的索引

    # 返回一個tensor最大值的索引
    a = torch.tensor([[1,2,3],[2,3,4]])
    print(torch.argmax(b))

    # 輸出值
    # tensor(5)

情況2:返回指定維度的最大元素索引

    # 將第0維的數據進行比較
    b = torch.tensor([[1,2,3],[2,3,4]])
    print(torch.argmax(b,dim=0)

    # 輸入結果
    # tensor([1, 1, 1])


    # 將第1維數據進行比較
    print(torch.argmax(b,dim=1))

    # 輸出結果
    # tensor([2,2])

 

3. torch.nonzero()函數

torch.nonzero()函數有很多功能,比如能夠配合mask篩選出你想要的數據。比如在目標檢測求confidence小於threshold的bbox,或者只選擇某一類別的bbox,都可以使用nonzero()函數進行篩選。

torch.nonzero()函數簽名如下:

    torch.nonzero(input, out=None) -> Tensor

這個函數的主要就是找到tensor中所有不爲零元素的索引。函數的返回值是一個z*n的tensor。z表示不爲零的元素的個數,n表示輸入tensor的維度。

    # 打印a中所有不爲0的元素的索引
    a = torch.tensor([[1,2,3],[1,2,0]])
    index = torch.nonzero(a)
    print(index)

輸出結果是這樣的:

    tensor([[0, 0],
            [0, 1],
            [0, 2],
            [1, 0],
            [1, 1]])

但是,我目前主要用的就是確定某一列中不含0的索引,然後再將其對應的某些行全部過濾出來。

    # 第2列中有一個非零數據,我們想把非零列所對應的行過濾出來
    a = torch.tensor([[1,2,3],[1,2,0]])
    index = torch.nonzero(a[:,2])
    print(a[index.squeeze()])

    # 輸出效果
    tensor([1, 2, 3])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章