pytorch中保存和加載模型是綁在一起的。
這裏我需要注意一下不同的保存方式對應不同的讀取方式,兩者各有利弊。
首先說說pytorch.save()這個函數,可以參考官網:pytroch.save。
簡而言之,這個函數可以保存任意的東西,比如tensor或者模型,或者僅僅是模型的參數。
如果將保存對象侷限在模型上,通常來說我們有兩種方式:直接保存所有的模型,只保存模型中的參數(模型結構就保存了)。以下分別說說兩種不同的方式。
爲了說明,我們先建立一個簡單的模型。
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, in_c, out_c, ngf=64):
super(Generator, self).__init__()
model = []
model += [
nn.Conv2d(in_c, ngf, 3, 2, 1),
nn.ReLU(),
nn.BatchNorm2d(ngf),
nn.Conv2d(ngf, out_c, 3, 2, 1)
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
netG = Generator(3, 3)
input = torch.zeros(10, 3, 256, 256)
output = netG(input)
直接保存所有模型並讀取
直接使用簡單粗暴的方式保存:
torch.save(netG, 'netG.pt')
對應的,我們可以這樣讀取模型
netC = torch.load('netG.pt')
input = torch.zeros(10, 3, 256, 256)
output = netC(input)
正常情況如下(警告先忽略):
只保存模型中的參數並讀取
我們說模型的參數保存在網絡的state_dict中,使用這個就可以讀取網絡的參數了。
torch.save({'netG': netG.state_dict()}, 'model_test.pt')
對應的加載模型的方式如下:
netD = Generator(3, 3)
state_dict = torch.load('model_test.pt')
netD.load_state_dict(state_dict['netG'])
input = torch.zeros(10, 3, 256, 256)
output = netD(input)
總結
我們可以看到第一種方法可以直接保存模型,加載模型的時候直接把讀取的模型給一個參數就行。而第二種方法則只是保存參數,在讀取模型參數前要先定義一個模型(模型必須與原模型相同的構造),然後對這個模型導入參數。雖然麻煩,但是可以同時保存多個模型的參數,而第一種方法則不能,而且第一種方法有時不能保證模型的相同性(你讀取的模型並不是你想要的)。
總的來說,我們一般來選擇第二種來保存和讀取。
退一步講,如何保存模型決定了如何讀取模型。