Pytorch(二):張量的基本操作:拼接,切分,索引,變換

       目錄

1.拼接

torch.cat()

torch.stack()

2.切分:

torch.chunk()

torch.split()

3.索引

torch.index_select()

torch.masked_select()

torch.ge(),gt(),le(),lt()

 4.變換:

torch.reshape()

torch.transpose()

torch.t()

torch.squeeze()

torch.unsqueeze()


1.拼接

torch.cat()

  • 聲明
torch.cat(tensors, dim=0, out=None) → Tensor
  • 功能:將張量按維度dim進行拼接

torch.stack()

  • 聲明
torch.stack(tensors, dim=0, out=None) → Tensor
  • 功能:在新創建的維度dim進行拼接

注意:cat()不會擴展張量的維度,而stack()會拓展張量的維度。

 

 

2.切分:

torch.chunk()

  • 聲明:
torch.chunk(input, chunks, dim=0) → List of Tensors
  • 功能:將張量按維度dim進行平均切分,返回張量列表。
  • 成員變量:
  1. input:要切分的張量
  2. chunks:要切分的份數
  3. dim:要切分的維度

注意:若不能整除,最後一項張量小於其他張量 

torch.split()

  • 聲明:
torch.split(tensor, split_size_or_sections, dim=0)
  • 功能:將張量按維度dim進行切分,返回張量列表。
  • 成員變量:
  1. tensor:要切分的張量
  2. split_size_or_sections:爲int時,表示每一份的長度;爲list時,按list元素切分
  3. dim:要切分的維度

注意:如果使用list作爲參數,則元素總和等於切分前的數量。

3.索引

torch.index_select()

  • 聲明:
torch.index_select(input, dim, index, out=None) → Tensor
  • 功能:在維度dim上,按index索引數據,返回依index索引數據拼接的張量
  • 成員變量:
  1. input:要索引的張量
  2. dim:要索引的維度
  3. index:要索引數據的序號(LongTensor

torch.masked_select()

  • 聲明:
torch.masked_select(input, mask, out=None) → Tensor
  • 功能:按mask中的true進行索引,返回一維張量
  • 成員變量:
  1. input:要索引的張量
  2. mask:與input同形狀的布爾類型張量(ByteTensor

torch.ge(),gt(),le(),lt()

  • 聲明:
torch.ge(input, other, out=None) → Tensor
torch.gt(input, other, out=None) → Tensor
torch.le(input, other, out=None) → Tensor
torch.lt(input, other, out=None) → Tensor
  •  功能:生成一個 input >= other,input > other,input <= other,input < other的bool Tensor.
  • 成員變量:
  1. input (Tensor) – the tensor to compare

  2. other (Tensor or python:float) – the tensor or value to compare

 4.變換:

torch.reshape()

  • 聲明:
torch.reshape(input, shape) → Tensor

 

  • 功能:變換張量形狀

注意:當張量在內存中是連續時,新張量與input共享數據內存

  • 成員變量:
  1. input:要變換的張量
  2. shape:新張量的形狀(tuple of python)

注意:shape=(-1,2)中的-1表示我們不關心的維度,由系統決定。

torch.transpose()

  • 聲明:
torch.transpose(input, dim0, dim1) → Tensor

 

  • 功能:交換張量的兩個維度
  • 成員變量:
  1. input:要變換的張量
  2. dim0:要交換的維度
  3. dim1:要交換的維度

 

torch.t()

  • 聲明:
torch.t(input) → Tensor
  • 功能:2維張量裝置,對矩陣而言等價於torch.transpose(input,0,1)
  • 成員變量:
  1. input:要變換的張量

torch.squeeze()

  • 聲明:
torch.squeeze(input, dim=None, out=None) → Tensor
  • 功能:壓縮長度爲1的維度(軸)
  • 成員變量:
  1. dim:若爲None,移除所有長度1的軸;若指定維度,當且僅當該軸長度爲1時,可以被移除;

torch.unsqueeze()

  • 聲明:
unsqueeze_(dim) → Tensor
  • 功能:y依據dim擴展維度
  • 成員變量:
  1. dim:擴展的維度

參考:

https://pytorch.org/docs/stable/tensors.html?highlight=unsque#torch.Tensor.unsqueeze

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章