PyTorch載入模型,並輸出參數

# define the model
model = Model()
for k, v in model.named_parameters():
    print(k, v.size())
#===================================#
pretrained_state = torch.load('pretrained_model.pth')
for i in pretrained_state:
    print(i, pretrained_state[i].size())
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章