Pytorch提供了兩種方法進行模型的保存和加載。
第一種(推薦):
該方法值保存和加載模型的參數
# 保存
torch.save(the_model.state_dict(), PATH)
# 加載
# 定義模型
the_model = TheModelClass(*args, **kwargs)
# 加載模型
the_model.load_state_dict(torch.load(PATH))
例如:
import torch
import torchvision.models as models
# 創建模型
model = models.resnet101().cuda()
'''
訓練過程
'''
# 保存訓練後的模型
torch.save(model.state_dict(), './resnet101_test.pt'.)
第二種:
保存和加載整個模型。
# 保存
torch.save(the_model, PATH)
# 加載
the_model = torch.load(PATH)