深度學習--第15篇: Pytorch保存和加載模型參數

參考博客

參考博客: https://blog.csdn.net/lscelory/article/details/81482586

pytorch的模型和參數是分開的,可以分別保存或加載模型和參數。
pytorch有兩種模型保存方式:

  • 保存整個神經網絡的的結構信息和模型參數信息,save的對象是網絡net
  • 只保存神經網絡的訓練模型參數,save的對象是net.state_dict()

對應兩種保存模型的方式,pytorch也有兩種加載模型的方式。對應第一種保存方式,加載模型時通過torch.load(’.pth’)直接初始化新的神經網絡對象;對應第二種保存方式,需要首先導入對應的網絡,再通過net.load_state_dict(torch.load(’.pth’))完成模型參數的加載。

在網絡比較大的時候,第一種方法會花費較多的時間。

1. 保存模型和參數

  • 保存模型
# 將網絡結構和模型參數都保存起來,在測試時可以直接加載,不需要初始化網絡結構
torch.save(model, path)

參數:
	model: 訓練的網絡
	pth: 保存的路徑(包含文件名,後綴名以.pth .pkl 等結尾)

實例:
torch.save(model, os.path.join('.', 'lenet.pth')) # 保存模型結構和參數
  • 加載模型
# 直接加載模型文件,不需要初始化網絡結構
model = torch.load(path)

參數:
	model: 加載後的網絡
	pth: 保存模型文件的路徑(包含文件名,後綴名以.pth .pkl 等結尾)

實例:
model = torch.load(os.path.join('.', 'lenet.pth'))# 加載模型結構和參數

2. 僅保存參數

  • 保存參數
# 將lenet模型儲存爲lenet.pth, 注意保存的僅僅是網絡模型的狀態信息參數字典, 加載是需要初始化網絡模型
torch.save(net.state_dict(), os.path.join('.', 'lenet.pth'))
  • 加載參數
# 加載lenet,模型存放在lenet.pth, 加載之前要確認網絡模型已初始化完成
model = torch.load(os.path.join('.', 'lenet.pth'))
net.load_state_dict(model)

3. 加載pytorch預訓練模型

3.1 加載預訓練模型和參數

import torchvision
AlexNet = torchvision.models.alexnet(pretrained=True) # 加載預訓練模型AlexNet和參數

resnet18 = torchvision.models.resnet18(pretrained=True)

3.2 只加載模型不加載預訓練參數

import torchvision
AlexNet = torchvision.models.alexnet(pretrained=False) # 加載預訓練模型AlexNet

# 導入模型結構
ResNet18 = models.resnet18(pretrained=False)
# 加載預先下載好的預訓練參數到resnet18
ResNet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章