原博文地址:https://blog.csdn.net/weixin_41278720/article/details/80759933
Pytorch中的torchvision包又包括3個子包,分別如下:
torchvison.datasets:預定義好的數據集(比如MNIST、CIFAR10等)
torchvision.models :預定義好的經典網絡結構(比如AlexNet、VGG、ResNet等)
torchvision.transforms :預定義好的數據增強方法(比如Resize、ToTensor等)
models這個包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的網絡結構,並且提供了預訓練模型,可以通過簡單調用來讀取網絡結構和預訓練模型。
1、加載官方提供的網絡模型
加載resnet50預訓練模型(包含訓練得到的權重與偏置參數)
import torchvision.models as models
resnet50 = models.resnet50(pretrained=True)
只加載resnet50網絡結構,並未使用預訓練模型的參數對其初始化(權重與偏置都是隨機值)
import torchvision.models as models
resnet50 = models.resnet50(pretrained=False)
2、保存、加載自己的網絡模型
方法一(推薦):只保存和加載模型中的參數,不保存其網絡結構
保存:將訓練參數保存在ckp文件夾中,文件名:model.pth
torch.save(resnet50.state_dict(),'ckp/model.pth')
加載:這裏的resnet50是我們自己實現的網絡,因此可以不必傳遞pretrained=True參數(官方提供的版本需要傳遞此參數)
resnet=resnet50() #加載網絡結構
resnet.load_state_dict(torch.load('ckp/model.pth')) #加載該網絡結構的預訓練參數
方法二:保存、加載網絡的結構和參數信息
保存
torch.save(resnet50,'model.pth')
加載
resnet50 = torch.load('model.pth')
方法三:選擇保存、加載網絡中的一部分參數或者保存額外的參數
保存
save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
torch.save({
'session': args.session,
'epoch': epoch + 1,
'model': fasterRCNN.module.state_dict() if args.mGPUs else fasterRCNN.state_dict(),
'optimizer': optimizer.state_dict(),
'pooling_mode': cfg.POOLING_MODE,
'class_agnostic': args.class_agnostic,
}, save_name)
加載
load_name = os.path.join(output_dir,
'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
checkpoint = torch.load(load_name)
args.session = checkpoint['session']
args.start_epoch = checkpoint['epoch']
fasterRCNN.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr = optimizer.param_groups[0]['lr']
cfg.POOLING_MODE = checkpoint['pooling_mode']