PyTorch函數解釋:cat、stack、transpose、permute、squeeze、unsqueeze

torch.cat() 張量拼接

對張量沿着某一維度進行拼接。連接後數據的總維數不變。,ps:能拼接的前提是對應的維度相同!!!

例如對兩個2維tensor(分別爲2*3,1*3)進行拼接,拼接完後變爲3*3的2維 tensor。

In [1]: import torch

In [2]: torch.manual_seed(1)
Out[2]: <torch._C.Generator at 0x19e56f02e50>

In [3]: x = torch.randn(2,3)

In [4]: y = torch.randn(1,3)

In [5]: x
Out[5]:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661]])

In [6]: y
Out[6]: tensor([[-1.5228,  0.3817, -1.0276]])

In [9]: torch.cat((x,y),0)
Out[9]:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661],
        [-1.5228,  0.3817, -1.0276]])

以上dim=0 表示按列進行拼接,dim=1表示按行進行拼接。

代碼如下:

In [11]: z = torch.randn(2,2)

In [12]: z
Out[12]:
tensor([[-0.5631, -0.8923],
        [-0.0583, -0.1955]])

In [13]: x
Out[13]:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661]])

In [14]: torch.cat((x,z),1)
Out[14]:
tensor([[ 0.6614,  0.2669,  0.0617, -0.5631, -0.8923],
        [ 0.6213, -0.4519, -0.1661, -0.0583, -0.1955]])

torch.stack() 張量堆疊

torch.cat()拼接不會增加新的維度,但torch.stack()則會增加新的維度。

例如對兩個1*2 維的 tensor 在第0個維度上stack,則會變爲2*1*2的 tensor;在第1個維度上stack,則會變爲1*2*2 的tensor。

In [22]: x = torch.randn(1,2)

In [23]: y = torch.randn(1,2)

In [24]: x.shape
Out[24]: torch.Size([1, 2])

In [25]: x = torch.randn(1,2)

In [26]: y = torch.randn(1,2)

In [27]: torch.stack((x,y),0) # 維度0堆疊
Out[27]:
tensor([[[-1.8313,  1.5987]],

        [[-1.2770,  0.3255]]])

In [28]: torch.stack((x,y),0).shape
Out[28]: torch.Size([2, 1, 2])

In [29]: torch.stack((x,y),1) # 維度1堆疊
Out[29]:
tensor([[[-1.8313,  1.5987],
         [-1.2770,  0.3255]]])

In [30]: torch.stack((x,y),1).shape
Out[30]: torch.Size([1, 2, 2])

torch.transpose() 矩陣轉置

舉例說明

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原來x的結果:

 0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

將x的維度互換:x.transpose(0,1) ,其實相當於轉置操作!
結果

0.6614  0.6213
 0.2669 -0.4519
 0.0617 -0.1661
[torch.FloatTensor of size 3x2]

torch.permute() 多維度互換

permute是更靈活的transpose,可以靈活的對原數據的維度進行調換,而數據本身不變。

In [31]: x = torch.randn(2,3,4)

In [32]: x
Out[32]:
tensor([[[ 0.7626,  0.4415,  1.1651,  2.0154],
         [ 0.2152, -0.5242, -1.8034, -1.3083],
         [ 0.4100,  0.4085,  0.2579,  1.0950]],

        [[-0.5065,  0.0998, -0.6540,  0.7317],
         [-1.4567,  1.6089,  0.0938, -1.2597],
         [ 0.2546, -0.5020, -1.0412,  0.7323]]])

In [33]: x.shape
Out[33]: torch.Size([2, 3, 4])

In [34]: x.permute(1,0,2) # 0維和1維互換,2維不變!
Out[34]:
tensor([[[ 0.7626,  0.4415,  1.1651,  2.0154],
         [-0.5065,  0.0998, -0.6540,  0.7317]],

        [[ 0.2152, -0.5242, -1.8034, -1.3083],
         [-1.4567,  1.6089,  0.0938, -1.2597]],

        [[ 0.4100,  0.4085,  0.2579,  1.0950],
         [ 0.2546, -0.5020, -1.0412,  0.7323]]])

In [35]: x.permute(1,0,2).shape
Out[35]: torch.Size([3, 2, 4])

torch.squeeze() 和 torch.unsqueeze()

常用來增加或減少維度,如沒有batch維度時,增加batch維度爲1。

  • squeeze(dim_n)壓縮,減少dim_n維度 ,即去掉元素數量爲1的dim_n維度。
  • unsqueeze(dim_n),增加dim_n維度,元素數量爲1。
In [38]: x = torch.randn(1,3,4)

In [39]: x.shape
Out[39]: torch.Size([1, 3, 4])

In [40]: x
Out[40]:
tensor([[[-0.4791,  0.2912, -0.8317, -0.5525],
         [ 0.6355, -0.3968, -0.6571, -1.6428],
         [ 0.9803, -0.0421, -0.8206,  0.3133]]])

In [41]: x.squeeze()
Out[41]:
tensor([[-0.4791,  0.2912, -0.8317, -0.5525],
        [ 0.6355, -0.3968, -0.6571, -1.6428],
        [ 0.9803, -0.0421, -0.8206,  0.3133]])

In [42]: x.squeeze().shape
Out[42]: torch.Size([3, 4])

In [43]: x.unsqueeze(0)
Out[43]:
tensor([[[[-0.4791,  0.2912, -0.8317, -0.5525],
          [ 0.6355, -0.3968, -0.6571, -1.6428],
          [ 0.9803, -0.0421, -0.8206,  0.3133]]]])

In [44]: x.unsqueeze(0).shape
Out[44]: torch.Size([1, 1, 3, 4])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章