Pytorch保存與加載模型

原博文地址:https://blog.csdn.net/weixin_41278720/article/details/80759933

Pytorch中的torchvision包又包括3個子包,分別如下:

torchvison.datasets:預定義好的數據集(比如MNIST、CIFAR10等)

torchvision.models :預定義好的經典網絡結構(比如AlexNet、VGG、ResNet等)

torchvision.transforms :預定義好的數據增強方法(比如Resize、ToTensor等)

models這個包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的網絡結構,並且提供了預訓練模型,可以通過簡單調用來讀取網絡結構和預訓練模型。

1、加載官方提供的網絡模型

加載resnet50預訓練模型(包含訓練得到的權重與偏置參數)

import torchvision.models as models
 
resnet50 = models.resnet50(pretrained=True)

只加載resnet50網絡結構,並未使用預訓練模型的參數對其初始化(權重與偏置都是隨機值)

import torchvision.models as models
 
resnet50 = models.resnet50(pretrained=False)

2、保存、加載自己的網絡模型

方法一(推薦):只保存和加載模型中的參數,不保存其網絡結構

保存:將訓練參數保存在ckp文件夾中,文件名:model.pth

torch.save(resnet50.state_dict(),'ckp/model.pth') 

加載:這裏的resnet50是我們自己實現的網絡,因此可以不必傳遞pretrained=True參數(官方提供的版本需要傳遞此參數)

resnet=resnet50()    #加載網絡結構
resnet.load_state_dict(torch.load('ckp/model.pth'))  #加載該網絡結構的預訓練參數

方法二:保存、加載網絡的結構和參數信息

保存

torch.save(resnet50,'model.pth') 

加載

resnet50 = torch.load('model.pth')

方法三:選擇保存、加載網絡中的一部分參數或者保存額外的參數

保存

save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
torch.save({
      'session': args.session,
      'epoch': epoch + 1,
      'model': fasterRCNN.module.state_dict() if args.mGPUs else fasterRCNN.state_dict(),
      'optimizer': optimizer.state_dict(),
      'pooling_mode': cfg.POOLING_MODE,
      'class_agnostic': args.class_agnostic,
    }, save_name)

加載

load_name = os.path.join(output_dir,
      'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
checkpoint = torch.load(load_name)
args.session = checkpoint['session']
args.start_epoch = checkpoint['epoch']
fasterRCNN.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr = optimizer.param_groups[0]['lr']
cfg.POOLING_MODE = checkpoint['pooling_mode']

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章