注:以下均基於導入庫
import torch
-
構造方法
常用函數
-
.reshape()#變形
-
.size()#返回大小
-
.dim()#返回維度,即size返回的數據條目個數
-
.numel()#返回元素個數,即size返回條目之積
-
.dtype#查看元素類型,int默認torch.int64,bool默認torch.uint8,float默認touch.float32
-
常用計算
-
.repeat() 重複
-
增減維度unsqueeze 、squeeze(不改變元素位置)
a=torch.Tensor([[1,2,3],[4,5,6]])
print(a.size())#torch.Size([2, 3])
b=a.unsqueeze(2)#括號裏爲維度增加的位置
print(b.size())#torch.Size([2, 3, 1])
c=b.squeeze(-1)#去維度
print(c.size())#torch.Size([2, 3])
-
.sequeeze()可消除張量大小中大小爲1的維度
-
.unsqueeze()可增加以下大小爲0的維度
-
permute()/transpose()/t()張量的交換、張量的轉置、二維張量轉置
-
Tensor連接cat
a=torch.Tensor([[1,2,3],[4,5,6]])
d=torch.ones([2,3])
e=torch.cat([a,d],1)
print(e)
print(torch.cat([a **i for i in range(3)],1))
#result:
# tensor([[1., 2., 3., 1., 1., 1.],
# [4., 5., 6., 1., 1., 1.]])
# tensor([[ 1., 1., 1., 1., 2., 3., 1., 4., 9.],
# [ 1., 1., 1., 4., 5., 6., 16., 25., 36.]])
-
大於某個值ge
a=torch.Tensor([[1,2,3],[4,5,6]])
b=torch.Tensor([0.3,0.40,0.6,0.8,0.1])
print(b.ge(0.5))#tensor([0, 0, 1, 1, 0], dtype=torch.uint8)
print(b.ge(0.5).float())#tensor([0., 0., 1., 1., 0.])
cor=((b.ge(0.5).float()==torch.ones(5)).sum())
print(cor)#tensor(2)
print(b.size(),b.size(0))#torch.Size([5]) 5