如何保存和读取pytorch模型

如何保存和读取pytorch模型


相信大家也会遇到这样的问题吧,在使用pytorch训练自己模型的时候,如果不将我们训练的模型保存起来,我们每一次都是从头开始训练我们的模型,这样真的很麻烦。其实在我的上一篇博客中我已经发现这个问题了。

1.保存模型

#定义保存模型函数
def save_model(the_model,PATH):
    torch.save(the_model.state_dict(),PAT

当我们的模型训练完毕之后,我们只需调用一下该函数就可以了

save_model(cnn,'cnn.pth')
#这里的cnn就是我要保存的训练好的模型,cnn.pth就是要保存为的名称,
#一般来说pytorch的模型后缀都是.pth

2.读取模型

例如我们想要在另外的一个python文件中读取我们之前已经保存好的模型,我们需要先创建一个和之前模型一样的空模型来接收。

import torch
from cnn_test import CNN

best_model=CNN()
#定义一个与之前模型一致结构的模型来接收
best_model.load_state_dict(torch.load('cnn.pth'))
#加载之前的模型,这里的‘cnn.pth’就是我上一步保存的模型文件
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章