pytorch 張量的操作

Tensor Operation

1. 張量拼接與切分

torch.cat():將張量的維度dim進行拼接,不會擴張張量的維度

                    如dim=0,則兩個向量將在第0維進行拼接:(3,4)concat(3,4)-->(6,4)

torch.stack():在新創建的維度dim上進行拼接

                    如dim=0,則(3,4)stack(3,4)-->(2,3,4)

                    如dim=2,則(3,4)stack(3,4)-->(3,4,2)

torch.chunk():將張量維度dim進行平均切分,返回張量列表。

                         若不能整除,最後一份張量小於其他張量

   chunk:要切的份數

torch.split():將張量按維度dim進行切分,返回張量列表

split_size_or_sections:當爲int時,表示每一份的長度;當爲list時,按list元素切分

 

2. 張量索引

torch.index_select():在維度dim上,按index索引數據。依index索引數據拼接的張量。

index:是dtype爲torch.long的tensor

t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 1], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)

torch.mask_select():按mask中的True進行索引,返回一維張量。

t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5)  #  >=5 return true; else false
t_select = torch.masked_select(t, mask)

 

3. 張量變換

torch.reshape():變換張量形狀。當張量在內存中是連續時,新張量與input共享數據內存。

torch.transpose():變換張量的兩個維度

torch.t():二維張量的轉置

torch.squeeze():壓縮長度爲1的維度。

  dim爲None時,移除所有長度爲1的軸;若指定維度,當且僅當該軸長度爲1時,可以被移除

torch.unqueeze():依據dim擴展維度

 

Tensor Math Operation

torch.add():逐元素計算input+alpha+other

torch.addcdiv():

                        

torch.addcmul():

                           

                         

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章