1.保存ConvNet
使用torch.save()
對網絡結構和模型參數的保存有兩種保存方式:
- 保存整個神經網絡的結構信息和模型參數信息,save的對象是網絡net;
- 保存神經網絡的訓練模型參數,save的對象是net.state_dict().
torch.save(net, 'net.pkl') # 保存整個神經網絡的結構和模型參數
torch.save(net.state_dict(), 'net_params.pkl') # 只保存神經網絡的模型參數
2.加載ConvNet
對應上面兩種保存方式, 重載方式也有兩種。
- 對應第一種完整網絡結構信息,重載的時候通過torch.load(‘.pth’)直接初始化新的神經網絡對象即可。
- 對應第二種只保存模型參數信息,需要首先導入對應的網絡,通過net.load_state_dict(torch.load(‘.pth’))完成模型的重載。
在網絡比較大的時候,第一種方法會話費較多的時間,所佔的存儲空間也比較大。
# 保存和加載整個模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')
# 保存和加載模型參數
torch.save(model_object.state_dict(), 'params.pth')
model_object.load_state_dict(torch.load('params.pth'))