pytorch 加載已訓練好的(.pth)格式模型--轉載

1 簡介

pytorch裏有一些非常流行的網絡如 resnet、wide_resnet101_2、squeezenet、densenet等,包括網絡結構和訓練好的模型。
pytorch自帶模型網址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/

按照官網加載預訓練的模型:

import torchvision.models as models

# pretrained=True就可以使用預訓練的模型
resnet18 = models.resnet18(pretrained=True)
print(resnet18)

  • 1
  • 2
  • 3
  • 4
  • 5

可能會出現以下錯誤:

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\Administrator/.cache\torch\checkpoints\resnet18-5c106cde.pth
  • 1

主要原因是:國內的網有些時候連接不上,需要我們手動去下載想要的預訓練網絡。

2 下載相應模型

可以從報錯代碼中複製網址進行下載:
https://download.pytorch.org/models/resnet18-5c106cde.pth

也可以從Pytorch的github下找模型的地址:
https://github.com/pytorch/vision/tree/master/torchvision/models
找到對應模型名稱點進去找地址
在這裏插入圖片描述

3 加載已保存的模型

下載好後自行保存,保存到自己能找到的地址。

接下來就是運行這個.pth文件。首先要判斷是保存的整個網絡結構加參數呢,還是隻保存了參數,可以測試一下。這是我的模型是resnet34,你可以測試自己下載的模型

import torch
import torchvision.models as models

# pretrained=True就可以使用預訓練的模型
net = models.resnet34(pretrained=False)
pthfile = r’C:\Users\Administrator\tianchi\model\resnet34-333f7ec4.pth’
net.load_state_dict(torch.load(pthfile))
print(net)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

結果:
在這裏插入圖片描述
這樣就加載好預訓練模型了

參考鏈接:

https://tianchi.aliyun.com/notebook-ai/notebookEdit?notebookLabId=96053&version=0

https://blog.csdn.net/u014264373/article/details/85332181

網盤下載鏈接:

鏈接:https://pan.baidu.com/s/12jdjQCeT0xYH7OLciMUi4w
提取碼:ly9d

                                </div>
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章