pytorch矩陣變換相關筆記

pytorch的通道順序和caffe一致,是(B C H W)

 

下面函數例子均基於一個feature map: feat, 通道順序爲(B C H W)

 

1.permute

feat.permute(0,2,3,1),將原來序號爲0,1,2,3的通道調整爲0,2,3,1,對於feat來說這個操作就是將通道順序變爲tensorflow風格。

 

2.squeeze和unsqueeze

squeeze(dim)是對維度爲1的維度dim進行降維,如[3,2,1] -> [3,2]。

unsqueeze(dim)是在dim處插入維度爲1的維度。

import torch

a = torch.Tensor([[[1],[3]],[[5],[7]],[[9],[11]]])
a1 = a.squeeze(2)
a2 = a1.unsqueeze(0)

if __name__ == '__main__':
    print(a.size())
    print(a1.size())
    print(a2.size())

 輸出爲:

torch.Size([3, 2, 1])
torch.Size([3, 2])
torch.Size([1, 3, 2])

 

3.expand,repeat

首先都是用於一個維度的複製擴展。expand不佔用新空間,repeat要用。

import torch

a = torch.Tensor([[[1],[3]],[[5],[7]],[[9],[11]]])
a1 = a.expand([3,2,4])
a2 = a.repeat([3,2,4])
if __name__ == '__main__':
    print(a.size())
    print(a1)
    print(a2)

輸出爲:

torch.Size([3, 2, 1])
tensor([[[  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.]],

        [[  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.]],

        [[  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.]]])
tensor([[[  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.],
         [  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.]],

        [[  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.],
         [  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.]],

        [[  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.],
         [  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.]],

        [[  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.],
         [  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.]],

        [[  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.],
         [  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.]],

        [[  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.],
         [  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.]],

        [[  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.],
         [  1.,   1.,   1.,   1.],
         [  3.,   3.,   3.,   3.]],

        [[  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.],
         [  5.,   5.,   5.,   5.],
         [  7.,   7.,   7.,   7.]],

        [[  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.],
         [  9.,   9.,   9.,   9.],
         [ 11.,  11.,  11.,  11.]]])

 

4.view

view是一個很有意思的函數,總的來說,就是把整個高維數組按照數字出現的順序展開成一維,然後在不改變數字順序的情況下,按照view給出的shape生成新的數組。

import torch

a = torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]],[[9,10],[11,12]]])
a1 = a.view([4,3,1])
a2 = a.view([2,6])
if __name__ == '__main__':
    print(a.size())
    print(a1)
    print(a2)

結果爲:

torch.Size([3, 2, 2])
tensor([[[  1.],
         [  2.],
         [  3.]],

        [[  4.],
         [  5.],
         [  6.]],

        [[  7.],
         [  8.],
         [  9.]],

        [[ 10.],
         [ 11.],
         [ 12.]]])
tensor([[  1.,   2.,   3.,   4.,   5.,   6.],
        [  7.,   8.,   9.,  10.,  11.,  12.]])

 

5.gather

由淺入深,先引用一個@江戶川柯壯 的例子

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
--------------------- 
作者:江戶川柯壯 
來源:CSDN 
原文:https://blog.csdn.net/edogawachia/article/details/80515038 

輸出分別爲:

 1  2  3
 4  5  6
[torch.FloatTensor of size 2x3]


 1  2
 6  4
[torch.FloatTensor of size 2x2]


 1  5  6
 1  2  3
[torch.FloatTensor of size 2x3]

dim是最難理解的一個變量。這裏我的理解是,dim是index中定義序號的維度。dim=0即序號是沿着縱向數的,橫向同一行的數字序號一樣,dim=1則反過來。

index的維度也需要注意。首先index每次是要把其他維遍歷一遍的,即index中除了dim維,其他維的形狀要和原數組保持一致。輸入是[2,3], dim=0時,index可能的維數爲[N,3],即可以遍歷數組的列多次,但是每一次必須把所有列都遍歷一遍。其次,index的維數要和輸入數組一致,即使只遍歷一次,二維輸入對應的index也要是二維的[1,3],而不能是一維的3.

再舉一個3維的例子:

import torch

a = torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]],[[9,0],[11,12]]])
index = torch.LongTensor([[[1,0]],[[0,0]],[[1,1]]])


if __name__ == '__main__':
    print(a.size())
    print(index.size())
    print(torch.gather(a,1,index))

輸出爲:

torch.Size([3, 2, 2])
torch.Size([3, 1, 2])
tensor([[[  3.,   2.]],

        [[  5.,   6.]],

        [[ 11.,  12.]]])

熟記除了dim維其他維度必須與原矩陣保持一致,最終的效果就一目瞭然了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章