pytorch常用tensor操作合集(轉)

目錄

gather

squeeze 

expand

sum

contiguous

softmax

max

argmax


gather

torch.gather(input,dim,index,out=None)。對指定維進行索引。比如4*3的張量,對dim=1進行索引,那麼index的取值範圍就是0~2.

input是一個張量,index是索引張量。input和index的size要麼全部維度都相同,要麼指定的dim那一維度值不同。輸出爲和index大小相同的張量。


import torch
a=torch.tensor([[.1,.2,.3],
                [1.1,1.2,1.3],
                [2.1,2.2,2.3],
                [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
                    [2,2,2],
                    [2,2,2],
                    [1,1,0]])
b=b.view(4,3)
 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

輸出:

tensor([[ 0.2000,  0.3000,  0.2000],
        [ 1.3000,  1.3000,  1.3000],
        [ 2.3000,  2.3000,  2.3000],
        [ 3.2000,  3.2000,  3.1000]])
tensor([[ 1.1000,  2.2000,  1.3000],
        [ 2.1000,  2.2000,  2.3000],
        [ 2.1000,  2.2000,  2.3000],
        [ 1.1000,  1.2000,  0.3000]])
tensor([[ 0.2000],
        [ 1.3000],
        [ 2.1000],
        [ 3.2000]])

squeeze 

將維度爲1的壓縮掉。如size爲(3,1,1,2),壓縮之後爲(3,2)

    import torch
    a=torch.randn(2,1,1,3)
    print(a)
    print(a.squeeze())
 

輸出:

tensor([[[[-0.2320,  0.9513,  1.1613]]],
 
 
        [[[ 0.0901,  0.9613, -0.9344]]]])
tensor([[-0.2320,  0.9513,  1.1613],
        [ 0.0901,  0.9613, -0.9344]])

expand

擴展某個size爲1的維度。如(2,2,1)擴展爲(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

輸出:

tensor([[[ 0.0608],
         [ 2.2106]],
 
        [[-1.9287],
         [ 0.8748]]])
tensor([[[ 0.0608,  0.0608,  0.0608],
         [ 2.2106,  2.2106,  2.2106]],
 
        [[-1.9287, -1.9287, -1.9287],
         [ 0.8748,  0.8748,  0.8748]]])

sum

size爲(m,n,d)的張量,dim=1時,輸出爲size爲(m,d)的張量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

輸出:

tensor(60)
tensor([[  5,  10,  15],
        [  5,  10,  15]])

contiguous

返回一個內存爲連續的張量,如本身就是連續的,返回它自己。一般用在view()函數之前,因爲view()要求調用張量是連續的。可以通過is_contiguous查看張量內存是否連續。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous)
 
print(a.contiguous().view(4,3))

輸出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[  1,   2,   3],
        [  4,   8,  12],
        [  1,   2,   3],
        [  4,   8,  12]])
        

softmax

假設數組V有C個元素。對其進行softmax等價於將V的每個元素的指數除以所有元素的指數之和。這會使值落在區間(0,1)上,並且和爲1。

S_{i}=\frac{e^{v_{i}}}{ \sum_{i=1}^{C} e^{v_{i}} }

import torch
import torch.nn.functional as F
 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

輸出:

tensor([[ 0.5000,  0.5000],
        [ 0.7311,  0.2689],
        [ 0.8808,  0.1192],
        [ 0.2689,  0.7311],
        [ 0.1192,  0.8808]])
      

max

返回最大值,或指定維度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
                [1.1,1.2,1.3],
                [2.1,2.2,2.3],
                [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

輸出:

    (tensor([ 0.3000,  1.3000,  2.3000,  3.3000]), tensor([ 2,  2,  2,  2]))
    tensor(3.3000)

argmax

返回最大值的index

    import torch
    a=torch.tensor([[.1,.2,.3],
                    [1.1,1.2,1.3],
                    [2.1,2.2,2.3],
                    [3.1,3.2,3.3]])
    print(a.argmax(dim=1))
    print(a.argmax())
  

輸出:

tensor([ 2,  2,  2,  2])
tensor(11)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章