pytorch(三)——Tensor的運算

注:以下均基於導入庫

import torch
  • 增減維度unsqueeze 、squeeze

a=torch.Tensor([[1,2,3],[4,5,6]])
print(a.size())#torch.Size([2, 3])

b=a.unsqueeze(2)#括號裏爲維度增加的位置
print(b.size())#torch.Size([2, 3, 1])

c=b.squeeze(-1)#去維度
print(c.size())#torch.Size([2, 3])
  • Tensor連接cat

a=torch.Tensor([[1,2,3],[4,5,6]])
d=torch.ones([2,3])
e=torch.cat([a,d],1)
print(e)
print(torch.cat([a **i for i in range(3)],1))
#result:
# tensor([[1., 2., 3., 1., 1., 1.],
#         [4., 5., 6., 1., 1., 1.]])
# tensor([[ 1.,  1.,  1.,  1.,  2.,  3.,  1.,  4.,  9.],
#         [ 1.,  1.,  1.,  4.,  5.,  6., 16., 25., 36.]])
  • 大於某個值ge

a=torch.Tensor([[1,2,3],[4,5,6]])
b=torch.Tensor([0.3,0.40,0.6,0.8,0.1])
print(b.ge(0.5))#tensor([0, 0, 1, 1, 0], dtype=torch.uint8)
print(b.ge(0.5).float())#tensor([0., 0., 1., 1., 0.])
cor=((b.ge(0.5).float()==torch.ones(5)).sum())
print(cor)#tensor(2)
print(b.size(),b.size(0))#torch.Size([5]) 5

 

 

 

 

 

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章