Pytorch學習記錄

1.保存ConvNet

使用torch.save()對網絡結構和模型參數的保存有兩種保存方式:

  • 保存整個神經網絡的結構信息和模型參數信息,save的對象是網絡net;
  • 保存神經網絡的訓練模型參數,save的對象是net.state_dict().
torch.save(net, 'net.pkl')  # 保存整個神經網絡的結構和模型參數
torch.save(net.state_dict(), 'net_params.pkl')  # 只保存神經網絡的模型參數

2.加載ConvNet

對應上面兩種保存方式, 重載方式也有兩種。

  • 對應第一種完整網絡結構信息,重載的時候通過torch.load(‘.pth’)直接初始化新的神經網絡對象即可。
  • 對應第二種只保存模型參數信息,需要首先導入對應的網絡,通過net.load_state_dict(torch.load(‘.pth’))完成模型的重載。
    在網絡比較大的時候,第一種方法會話費較多的時間,所佔的存儲空間也比較大。
# 保存和加載整個模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')

# 保存和加載模型參數
torch.save(model_object.state_dict(), 'params.pth')
model_object.load_state_dict(torch.load('params.pth'))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章