1 使用 torch.cat 拼接
略略略
2 使用 torch.stack 拼接
mx = torch.ones(( 3, 2))
print(mx, mx.shape)
t1 = torch.stack([mx, mx], dim= 2) #在新創建的維度上進行拼接
print(t1, t1.shape) #拼接完會從2維變成3維
mx = torch.ones(( 1, 2))
print(mx, mx.shape)
t1 = torch.stack([mx, mx], dim= 2) #在新創建的維度上進行拼接
print(t1, t1.shape)
mx = torch.ones(( 3, 2))
print(mx, mx.shape)
t1 = torch.stack(mx, dim= 2) # Error
print(t1, t1.shape)
mx = torch.ones(( 1, 2))
print(mx, mx.shape)
t1 = torch.stack(mx, dim= 2) # Error
print(t1, t1.shape)
output,_ = self.lstm(inputs)
ds = self.dense(output)
ds_stack = torch.stack(ds, dim=2)