[PyTorch] Tensor operations


Transposing: transpose and permute

For a 2-D matrix, transposing means swapping rows and columns: the element X[i][j] becomes X[j][i] after the transpose.

torch.transpose

The transpose function

torch.transpose(input, dim0, dim1) → Tensor

transpose swaps dimension dim0 and dimension dim1 of input.

  • input (Tensor) – the input tensor.
  • dim0 (int) – the first dimension to be transposed
  • dim1 (int) – the second dimension to be transposed

Two ways to call transpose

To transpose dimensions 0 and 1 of x, the following two calls are equivalent:
torch.transpose(x, 0, 1)
x.transpose(0, 1)
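
A quick check confirms both forms return the same result (a minimal sketch, assuming x is any tensor with at least two dimensions):

>>> x = torch.randn(2, 3)
>>> torch.equal(torch.transpose(x, 0, 1), x.transpose(0, 1))
True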

Chaining transpose calls to rearrange a higher-dimensional tensor

transpose can only swap two dimensions at a time. For a tensor with more than two dimensions, we can chain several transpose calls to get the ordering we want, as follows:
we create a 2x3x2x2 tensor, then transpose it twice.

import torch
x = torch.randn(2,3,2,2)  # create a 2x3x2x2 tensor
y = x.transpose(1,2).transpose(2,3) # first swap dims 1 and 2, so the shape becomes (2,2,3,2)
                                    # then swap dims 2 and 3, so the shape becomes (2,2,2,3)
/*
x before the transpose:
tensor([[[[ 0.7113,  1.0002],
          [ 0.1047, -1.8522]],

         [[ 0.8429, -0.5547],
          [ 0.3536,  0.1121]],

         [[ 0.9539,  0.7841],
          [-0.0667,  0.6173]]],


        [[[ 0.6009,  0.6388],
          [-0.0523, -0.5926]],

         [[ 0.9110,  1.8832],
          [-0.8734, -1.9924]],

         [[-1.2680, -0.3895],
          [ 0.1211, -0.7359]]]])
y after the transpose:
tensor([[[[ 0.7113,  0.8429,  0.9539],
          [ 1.0002, -0.5547,  0.7841]],

         [[ 0.1047,  0.3536, -0.0667],
          [-1.8522,  0.1121,  0.6173]]],


        [[[ 0.6009,  0.9110, -1.2680],
          [ 0.6388,  1.8832, -0.3895]],

         [[-0.0523, -0.8734,  0.1211],
          [-0.5926, -1.9924, -0.7359]]]])
*/
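
As a sanity check, the two chained transposes above amount to a single permute call; a minimal sketch, reusing the same x:

>>> y = x.transpose(1,2).transpose(2,3)   # shape (2, 2, 2, 3)
>>> torch.equal(y, x.permute(0, 2, 3, 1)) # one permute produces the same reordering
True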

torch.Tensor.permute

The permute function

permute rearranges the dimensions of a tensor into the given order.

permute(*dims) → Tensor

  • *dims (int…) – The desired ordering of dimensions
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(2, 0, 1).size()  
torch.Size([5, 2, 3])
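
Like transpose, permute returns a view that shares storage with the original tensor, and that view is generally no longer contiguous in memory; a small sketch, continuing with the x above:

>>> y = x.permute(2, 0, 1)
>>> y.data_ptr() == x.data_ptr()   # same underlying storage, no data is copied
True
>>> y.is_contiguous()              # the permuted view is not contiguous
False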

torch.Tensor.view

view changes the shape of a tensor, i.e. it returns a tensor with the specified shape.

view(*shape) → Tensor

  • shape (torch.Size or int…) – the desired size
x = torch.randn(4, 4)
z = x.view(-1, 8) # -1 means this dimension is inferred from the other dimensions
/**
x
tensor([[ 0.6943,  0.1583, -0.4266,  0.6899],
        [ 0.4207,  1.0217,  0.2867, -1.0755],
        [ 1.7662,  0.3465, -1.0772, -1.4075],
        [-0.9731,  0.0983, -1.4642, -1.3546]])
**/
z
tensor([[ 0.6943,  0.1583, -0.4266,  0.6899,  0.4207,  1.0217,  0.2867, -1.0755],
        [ 1.7662,  0.3465, -1.0772, -1.4075, -0.9731,  0.0983, -1.4642, -1.3546]])

Also note that view does not change the tensor layout in memory. In the example below, b (obtained with transpose) has the same shape as c (obtained with view), but b and c are not equal.

>>> a = torch.randn(1, 2, 3, 4)
>>> a.size()
torch.Size([1, 2, 3, 4])
>>> b = a.transpose(1, 2)  # Swaps 2nd and 3rd dimension
>>> b.size()
torch.Size([1, 3, 2, 4])
>>> c = a.view(1, 3, 2, 4)  # Does not change tensor layout in memory
>>> c.size()
torch.Size([1, 3, 2, 4])
>>> torch.equal(b, c)
False
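
Because view never reorders the underlying data, it also requires contiguous memory: calling view on the transposed b fails, while contiguous().view() or reshape() works. A minimal sketch, assuming the same a and b as above:

>>> b.view(1, 2, 3, 4)                      # raises RuntimeError: b is not contiguous
>>> b.contiguous().view(1, 2, 3, 4).size()  # copy into contiguous memory first
torch.Size([1, 2, 3, 4])
>>> b.reshape(1, 2, 3, 4).size()            # reshape copies only when it has to
torch.Size([1, 2, 3, 4])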

torch.Tensor.repeat

repeat tiles the tensor's data along the specified dimensions.

repeat(*sizes) → Tensor

  • sizes (torch.Size or int…) – The number of times to repeat this tensor along each dimension
x = torch.tensor([1, 2, 3])
x.repeat(4,2) # repeat 4 times along dim 0 and 2 times along dim 1
/**
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
**/
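
In general, the output shape is the tensor's shape (right-aligned and padded with 1s to the length of sizes) multiplied element-wise by sizes; a minimal sketch of that rule with the same x:

>>> x.repeat(4, 2).size()   # x is treated as shape (1, 3), giving (1*4, 3*2)
torch.Size([4, 6])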

To add dimensions, we can simply pass more repeat counts than the tensor has dimensions:

x.repeat(4,2,1) # promotes the 1-D tensor to a 3-D tensor of shape (4, 2, 3)
/**
tensor([[[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]],

        [[1, 2, 3],
         [1, 2, 3]]])
**/

torch.cat

Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
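
The shapes only have to match outside the concatenating dimension; a small sketch, assuming an extra tensor y of shape (2, 1):

>>> y = torch.randn(2, 1)
>>> torch.cat((x, y), 1).size()   # sizes along dim 1 may differ: 3 + 1 = 4
torch.Size([2, 4])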