Tensor Operation
1. 張量拼接與切分
torch.cat():將張量的維度dim進行拼接,不會擴張張量的維度
如dim=0,則兩個向量將在第0維進行拼接:(3,4)concat(3,4)-->(6,4)
torch.stack():在新創建的維度dim上進行拼接
如dim=0,則(3,4)stack(3,4)-->(2,3,4)
如dim=2,則(3,4)stack(3,4)-->(3,4,2)
torch.chunk():將張量維度dim進行平均切分,返回張量列表。
若不能整除,最後一份張量小於其他張量
chunk:要切的份數
torch.split():將張量按維度dim進行切分,返回張量列表
split_size_or_sections:當爲int時,表示每一份的長度;當爲list時,按list元素切分
2. 張量索引
torch.index_select():在維度dim上,按index索引數據。依index索引數據拼接的張量。
index:是dtype爲torch.long的tensor
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 1], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)
torch.mask_select():按mask中的True進行索引,返回一維張量。
t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5) # >=5 return true; else false
t_select = torch.masked_select(t, mask)
3. 張量變換
torch.reshape():變換張量形狀。當張量在內存中是連續時,新張量與input共享數據內存。
torch.transpose():變換張量的兩個維度
torch.t():二維張量的轉置
torch.squeeze():壓縮長度爲1的維度。
dim爲None時,移除所有長度爲1的軸;若指定維度,當且僅當該軸長度爲1時,可以被移除
torch.unqueeze():依據dim擴展維度
Tensor Math Operation
torch.add():逐元素計算input+alpha+other
torch.addcdiv():
torch.addcmul():