張量的拼接(與切分) torch.stack()

1 使用 torch.cat 拼接

     略略略

 


2 使用 torch.stack 拼接

 


mx = torch.ones(( 3, 2))
print(mx, mx.shape)  
t1 = torch.stack([mx, mx], dim= 2)  #在新創建的維度上進行拼接
print(t1, t1.shape)                 #拼接完會從2維變成3維


mx = torch.ones(( 1, 2))
print(mx, mx.shape)  
t1 = torch.stack([mx, mx], dim= 2)  #在新創建的維度上進行拼接
print(t1, t1.shape)       



mx = torch.ones(( 3, 2))
print(mx, mx.shape)  
t1 = torch.stack(mx, dim= 2)   # Error
print(t1, t1.shape)           


mx = torch.ones(( 1, 2))
print(mx, mx.shape)  
t1 = torch.stack(mx, dim= 2)   # Error
print(t1, t1.shape)     

 

 


output,_ = self.lstm(inputs)
ds = self.dense(output)
ds_stack = torch.stack(ds, dim=2)

 

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章