pytorch中torch.cat()函數理解

pytorch中torch.cat()函數:

功能:拼接兩個tensor。

用法:把兩個tensor A和B拼接在一起,可進行如下操作:

C = torch.cat( (A,B),0 )  #按維數0拼接(豎着拼)

C = torch.cat( (A,B),1 )  #按維數1拼接(橫着拼)

示例說明

1)按維數0拼接

>>> import torch

>>> A=torch.ones(2,3)    #2x3的張量(2行3列的矩陣)                                    

>>> A

tensor([[ 1.,  1.,  1.],

        [ 1.,  1.,  1.]])

>>> B=2*torch.ones(4,3)  #4x3的張量(4行3列的矩陣)                                   

>>> B

tensor([[ 2.,  2.,  2.],

        [ 2.,  2.,  2.],

        [ 2.,  2.,  2.],

        [ 2.,  2.,  2.]])

>>> C=torch.cat((A,B),0)  #按維數0(行)拼接

>>> C

tensor([[ 1.,  1.,  1.],

         [ 1.,  1.,  1.],

         [ 2.,  2.,  2.],

         [ 2.,  2.,  2.],

         [ 2.,  2.,  2.],

         [ 2.,  2.,  2.]])

>>> C.size()

torch.Size([6, 3])

 

2)按維數1拼接

>>> D=2*torch.ones(2,4) #2x4的張量(2行4列的矩陣)

>>> C=torch.cat((A,D),1)#按維數1(列)拼接

>>> C

tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],

        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])

>>> C.size()

torch.Size([2, 7])

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章