pytorch保存模型

一:只保存和加載模型參數
1 . 保存模型參數:

import torch
torch.save(model.state_dict(), 'save_path_name.pth')

2 . 加載模型參數:

import torch
import torch.nn as nn
model.load_state_dict(torch.load('save_path_name.pth'), strict=True)

方式二:保存和加載整個模型(模型結構和模型參數)
1 . 保存模型:

import torch
torch.save(model, 'save_path_name.pth')

2 . 加載模型:

import torch
import torch.nn as nn
model = torch.load('save_path_name.pth')




# 保存模型到路徑
torch.save(Batch_Net(28*28, 300, 100, 10), r'C:\Users\11868\Desktop\net.pth')
# 保存模型的參數
torch.save(model.state_dict(), r'C:\Users\11868\Desktop\state_dict.pth')
————————————————
版權聲明:本文爲CSDN博主「Answerlzd」的原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/Answer3664/article/details/98084300
# 加載模型
model = torch.load(r'C:\Users\11868\Desktop\net.pth')
# 加載參數
model.load_state_dict(torch.load(r'C:\Users\11868\Desktop\state_dict.pth'))
 
model.eval() # 將模型改爲測試模式
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章