注：以下代码均基于已导入的 torch 库（import torch）。
import torch
-
增减维度：unsqueeze、squeeze
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print(a.shape)  # torch.Size([2, 3])
# unsqueeze(dim): insert a new size-1 dimension at index `dim`.
b = torch.unsqueeze(a, 2)
print(b.shape)  # torch.Size([2, 3, 1])
# squeeze(-1): remove the last dimension again (it has size 1).
c = torch.squeeze(b, -1)
print(c.shape)  # torch.Size([2, 3])
-
Tensor 拼接：cat
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
d = torch.ones_like(a)  # a 2x3 tensor of ones, same shape and dtype as `a`
e = torch.cat((a, d), dim=1)  # join side by side along the column dimension
print(e)
# a**0 is all ones, a**1 is a, a**2 squares every element; join them column-wise.
print(torch.cat([a ** k for k in range(3)], dim=1))
# Expected output:
# tensor([[1., 2., 3., 1., 1., 1.],
#         [4., 5., 6., 1., 1., 1.]])
# tensor([[ 1.,  1.,  1.,  1.,  2.,  3.,  1.,  4.,  9.],
#         [ 1.,  1.,  1.,  4.,  5.,  6., 16., 25., 36.]])
-
大于等于某个值（≥）：ge
a = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # not used by the comparison demo below
b = torch.Tensor([0.3, 0.40, 0.6, 0.8, 0.1])
# ge(x) is an element-wise ">= x" test (greater than OR equal).
# Since PyTorch 1.2 comparison ops return dtype torch.bool;
# older releases returned torch.uint8 (printed as 0/1).
print(b.ge(0.5))  # tensor([False, False,  True,  True, False])
print(b.ge(0.5).float())  # tensor([0., 0., 1., 1., 0.])
# Count the elements >= 0.5. torch.ones_like(b) always matches b's shape,
# unlike the previous hard-coded torch.ones(5), which breaks if b's length
# changes. (b.ge(0.5).sum() yields the same count more directly.)
cor = (b.ge(0.5).float() == torch.ones_like(b)).sum()
print(cor)  # tensor(2)
print(b.size(), b.size(0))  # torch.Size([5]) 5