Pytorch學習(四) --- 模型的保存和加載

Pytorch提供了兩種方法進行模型的保存和加載。

第一種(推薦):
該方法值保存和加載模型的參數

# 保存
torch.save(the_model.state_dict(), PATH)
# 加載
# 定義模型
the_model = TheModelClass(*args, **kwargs)
# 加載模型
the_model.load_state_dict(torch.load(PATH))

例如:

import torch
import torchvision.models as models
# 創建模型
model = models.resnet101().cuda()
'''
訓練過程
'''
# 保存訓練後的模型
torch.save(model.state_dict(), './resnet101_test.pt'.)

第二種:
保存和加載整個模型。

# 保存
torch.save(the_model, PATH)
# 加載
the_model = torch.load(PATH)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章