pytorch軸

代碼:
import torch

class_num = 10
batch_size = 4
label = torch.Tensor(2,3,4,5).random_() % class_num
print(label.size())
label=label.permute(1,0,2,3)
print(label.size())
輸出:
torch.Size([2, 3, 4, 5])
torch.Size([3, 2, 4, 5])

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章