pytorch保存和加載模型的兩種方式

pytorch中保存和加載模型是綁在一起的。
這裏我需要注意一下不同的保存方式對應不同的讀取方式,兩者各有利弊。

首先說說pytorch.save()這個函數,可以參考官網:pytroch.save
簡而言之,這個函數可以保存任意的東西,比如tensor或者模型,或者僅僅是模型的參數。
如果將保存對象侷限在模型上,通常來說我們有兩種方式直接保存所有的模型,只保存模型中的參數(模型結構就保存了)。以下分別說說兩種不同的方式。

爲了說明,我們先建立一個簡單的模型。

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, in_c, out_c, ngf=64):
        super(Generator, self).__init__()
        model = []
        model += [
            nn.Conv2d(in_c, ngf, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(ngf),
            nn.Conv2d(ngf, out_c, 3, 2, 1)
        ]
        self.model = nn.Sequential(*model)
        
    def forward(self, x):
        return self.model(x)

netG = Generator(3, 3)
input = torch.zeros(10, 3, 256, 256)
output = netG(input)

直接保存所有模型並讀取

直接使用簡單粗暴的方式保存:

torch.save(netG, 'netG.pt')

對應的,我們可以這樣讀取模型

netC = torch.load('netG.pt')
input = torch.zeros(10, 3, 256, 256)
output = netC(input)

正常情況如下(警告先忽略):在這裏插入圖片描述

只保存模型中的參數並讀取

我們說模型的參數保存在網絡的state_dict中,使用這個就可以讀取網絡的參數了。

torch.save({'netG': netG.state_dict()}, 'model_test.pt')

對應的加載模型的方式如下:

netD = Generator(3, 3)
state_dict = torch.load('model_test.pt')
netD.load_state_dict(state_dict['netG'])
input = torch.zeros(10, 3, 256, 256)
output = netD(input)

總結

我們可以看到第一種方法可以直接保存模型,加載模型的時候直接把讀取的模型給一個參數就行。而第二種方法則只是保存參數,在讀取模型參數前要先定義一個模型(模型必須與原模型相同的構造),然後對這個模型導入參數。雖然麻煩,但是可以同時保存多個模型的參數,而第一種方法則不能,而且第一種方法有時不能保證模型的相同性(你讀取的模型並不是你想要的)。

總的來說,我們一般來選擇第二種來保存和讀取
退一步講,如何保存模型決定了如何讀取模型

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章