PyTorch: Reading and Setting Model Parameters, Using FlowNetSimple as an Example

1. Background

When doing deep learning, unless you have money to burn, you rarely train a network from scratch on a huge dataset. The cheaper and more efficient approach is usually to start from a pretrained model and fine-tune it for your own task.
This post is therefore a short record of how, in PyTorch, to read the parameters of a pretrained model and copy them into your own model, so as to save as much computation as possible.
To walk through the whole process concretely, this post sets up an experiment: I designed a network whose first half matches the encoder of FlowNetSimple and whose second half is a fully connected classification head.
The figure below shows the FlowNetSimple architecture; the refinement part is the decoder (similar to UNet).
The structure designed in this post simply removes the decoder and replaces it with fully connected layers; the code is straightforward, so I won't paste it here.


2. References

https://github.com/NVIDIA/flownet2-pytorch
《Dive into DL PyTorch》


3. Procedure

3.1 Downloading the pretrained model

The pretrained model I used is the FlowNetS checkpoint from https://github.com/NVIDIA/flownet2-pytorch.
If you are not familiar with FlowNetSimple, its code is shown below, with some brief comments of mine.

'''
Portions of this code copyright 2017, Clement Pinard
'''

import torch
import torch.nn as nn
from torch.nn import init

import math
import numpy as np

from .submodules import *
'Parameter count : 38,676,504 '

class FlowNetS(nn.Module):
    def __init__(self, args, input_channels=12, batchNorm=True):
        super(FlowNetS, self).__init__()

        # Everything below is the encoder. conv is a wrapper defined in this
        # repo's submodules, equivalent to Conv2d (optionally with BatchNorm)
        # followed by an activation.
        self.batchNorm = batchNorm
        self.conv1   = conv(self.batchNorm,  input_channels,   64, kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, 256,  256)
        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512,  512)
        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512,  512)
        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm, 1024, 1024)

        # Decoder: deconv is the upconvolution (transposed convolution) part
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        # These layers predict the optical flow; they are not relevant for
        # our purpose here
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    init.uniform_(m.bias)
                init.xavier_uniform_(m.weight)

            if isinstance(m, nn.ConvTranspose2d):
                if m.bias is not None:
                    init.uniform_(m.bias)
                init.xavier_uniform_(m.weight)
                # init_deconv_bilinear(m.weight)
        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, x):
        # Encoder
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # As the FlowNet paper explains, every decoder level predicts an
        # optical flow field at its own resolution
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = self.upsampled_flow6_to_5(flow6)
        out_deconv5 = self.deconv5(out_conv6)

        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = self.upsampled_flow5_to_4(flow5)
        out_deconv4 = self.deconv4(concat5)

        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = self.upsampled_flow4_to_3(flow4)
        out_deconv3 = self.deconv3(concat4)

        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = self.upsampled_flow3_to_2(flow3)
        out_deconv2 = self.deconv2(concat3)

        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2,

3.2 Inspecting the pretrained model's parameters

After finally downloading the pretrained model, we can start extracting its parameters. First, though, let's take a quick look at what this checkpoint actually contains. Note that the file downloaded from the link above is named FlowNet2-S_checkpoint.pth.tar.

3.2.1 Loading the checkpoint
# you do not need to instantiate the model first; torch.load can read the
# checkpoint contents directly
state_dict = torch.load('FlowNet2-S_checkpoint.pth.tar')
3.2.2 Printing the keys of state_dict
# A side note: PyTorch checkpoints are pickled Python dicts, so at their core
# they are just key-value pairs
for k, v in state_dict.items():
    print(k)

Output

epoch
best_EPE
state_dict

My first reaction to this output was confusion, but it quickly made sense: epoch and best_EPE store bookkeeping from training, which we are not interested in. state_dict is the part we actually care about.
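Since a checkpoint is just a pickled Python dict, the {epoch, best_EPE, state_dict} layout above is simply whatever the training script chose to save. Here is a minimal sketch reproducing that layout; the key names mirror the FlowNet checkpoint, while the tiny Linear model, the numbers, and the file path are stand-ins of my own:

```python
import os
import tempfile

import torch
import torch.nn as nn

# build a checkpoint with the same three keys as the FlowNet file
net = nn.Linear(4, 2)
ckpt = {'epoch': 3, 'best_EPE': 2.5, 'state_dict': net.state_dict()}

# save and reload it; torch.save simply pickles the dict
path = os.path.join(tempfile.mkdtemp(), 'demo_checkpoint.pth.tar')
torch.save(ckpt, path)
loaded = torch.load(path)

print(list(loaded.keys()))  # ['epoch', 'best_EPE', 'state_dict']
```

The nested state_dict entry is itself just a dict of parameter names to tensors, which is why the key iteration in the next subsection works.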

3.2.3 Printing the keys of state_dict['state_dict']
for k, v in state_dict['state_dict'].items():
    print(k)

Output

conv1.0.weight
conv1.0.bias
conv2.0.weight
conv2.0.bias
conv3.0.weight
conv3.0.bias
conv3_1.0.weight
conv3_1.0.bias
conv4.0.weight
conv4.0.bias
conv4_1.0.weight
conv4_1.0.bias
conv5.0.weight
conv5.0.bias
conv5_1.0.weight
conv5_1.0.bias
conv6.0.weight
conv6.0.bias
conv6_1.0.weight
conv6_1.0.bias
deconv5.0.weight
deconv5.0.bias
deconv4.0.weight
deconv4.0.bias
deconv3.0.weight
deconv3.0.bias
deconv2.0.weight
deconv2.0.bias
predict_flow6.weight
predict_flow6.bias
predict_flow5.weight
predict_flow5.bias
predict_flow4.weight
predict_flow4.bias
predict_flow3.weight
predict_flow3.bias
predict_flow2.weight
predict_flow2.bias
upsampled_flow6_to_5.weight
upsampled_flow5_to_4.weight
upsampled_flow4_to_3.weight
upsampled_flow3_to_2.weight

Notice that the keys above correspond one-to-one with the layers of FlowNetSimple, which lays the groundwork for reading them out next.

3.3 Inspecting our model's parameters

Nothing special here, but it is still worth iterating over the model we defined to see what parameters it has.

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

Output

conv1.0.weight
conv1.0.bias
conv2.0.weight
conv2.0.bias
conv3.0.weight
conv3.0.bias
conv3_1.0.weight
conv3_1.0.bias
conv4.0.weight
conv4.0.bias
conv4_1.0.weight
conv4_1.0.bias
conv5.0.weight
conv5.0.bias
conv5_1.0.weight
conv5_1.0.bias
conv6.0.weight
conv6.0.bias
conv6_1.0.weight
conv6_1.0.bias
fc_1.0.weight
fc_1.0.bias
fc_2.0.weight
fc_2.0.bias

As expected, most of the structure matches FlowNetSimple; the only difference is the two fully connected layers at the end.

3.4 Assigning the parameters

Normally, if the pretrained model and your own model are identical, model.load_state_dict(torch.load(PATH)) is all you need. Here, however, the two models differ: we only need a subset of the pretrained parameters and do not care about the rest. I took a very lazy approach and assigned the matching modules one by one.

3.4.1 Accessing the pretrained parameters

In the state_dict loaded above, the model parameters are also stored as key-value pairs, and can be read like this:

# the last part of the key can be changed to any corresponding layer
state_dict['state_dict']['conv2.0.bias']
3.4.2 Accessing our model's parameters

In this code, each layer of the model is an attribute of the model object, so its parameters are accessed like this:

# the [0] index is because I also used Sequential here, but the idea is the same
model.conv2[0].bias.data
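To see why the [0] index is needed: helpers like conv return an nn.Sequential that wraps the Conv2d together with its activation, so the actual convolution sits at index 0. The helper below is a simplified stand-in of my own for illustration, not the repo's exact code:

```python
import torch.nn as nn

# simplified stand-in for the repo's conv helper: a Conv2d plus an
# activation, wrapped in a Sequential
def conv(in_ch, out_ch, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride,
                  padding=(kernel_size - 1) // 2),
        nn.LeakyReLU(0.1, inplace=True),
    )

layer = conv(64, 128, kernel_size=5, stride=2)
print(type(layer[0]).__name__)  # Conv2d -- so layer[0].bias is the conv's bias
```

Indexing into a Sequential returns its submodules in order, which is also why the checkpoint keys look like conv2.0.bias: key "conv2.0" is submodule 0 of the Sequential stored in attribute conv2.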
3.4.3 Assignment

Forgive the brute force!

model.conv1[0].weight.data = state_dict['state_dict']['conv1.0.weight']
model.conv1[0].bias.data = state_dict['state_dict']['conv1.0.bias']
model.conv2[0].weight.data = state_dict['state_dict']['conv2.0.weight']
model.conv2[0].bias.data = state_dict['state_dict']['conv2.0.bias']

model.conv3[0].weight.data = state_dict['state_dict']['conv3.0.weight']
model.conv3[0].bias.data = state_dict['state_dict']['conv3.0.bias']
model.conv3_1[0].weight.data = state_dict['state_dict']['conv3_1.0.weight']
model.conv3_1[0].bias.data = state_dict['state_dict']['conv3_1.0.bias']

model.conv4[0].weight.data = state_dict['state_dict']['conv4.0.weight']
model.conv4[0].bias.data = state_dict['state_dict']['conv4.0.bias']
model.conv4_1[0].weight.data = state_dict['state_dict']['conv4_1.0.weight']
model.conv4_1[0].bias.data = state_dict['state_dict']['conv4_1.0.bias']

model.conv5[0].weight.data = state_dict['state_dict']['conv5.0.weight']
model.conv5[0].bias.data = state_dict['state_dict']['conv5.0.bias']
model.conv5_1[0].weight.data = state_dict['state_dict']['conv5_1.0.weight']
model.conv5_1[0].bias.data = state_dict['state_dict']['conv5_1.0.bias']

model.conv6[0].weight.data = state_dict['state_dict']['conv6.0.weight']
model.conv6[0].bias.data = state_dict['state_dict']['conv6.0.bias']
model.conv6_1[0].weight.data = state_dict['state_dict']['conv6_1.0.weight']
model.conv6_1[0].bias.data = state_dict['state_dict']['conv6_1.0.bias']
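For completeness, there is a less manual alternative: filter the pretrained state_dict down to the keys our model shares (with matching shapes), then load it with strict=False so the new fc layers keep their fresh initialization. This is a sketch using tiny stand-in models of my own; with the real checkpoint, pretrained would come from torch.load('FlowNet2-S_checkpoint.pth.tar')['state_dict'] and model would be the encoder-plus-fc network from this post.

```python
import torch
import torch.nn as nn

class Pretrained(nn.Module):  # stand-in for FlowNetS
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(12, 64, 7, stride=2, padding=3)
        self.deconv5 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)

class Mine(nn.Module):  # stand-in for our encoder + fc model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(12, 64, 7, stride=2, padding=3)
        self.fc_1 = nn.Linear(1024, 10)

pretrained = Pretrained().state_dict()
model = Mine()
own = model.state_dict()

# keep only entries that exist in our model with matching shapes
filtered = {k: v for k, v in pretrained.items()
            if k in own and v.shape == own[k].shape}

# strict=False ignores the keys our model has but the checkpoint lacks
model.load_state_dict(filtered, strict=False)
```

This scales to any number of shared layers without writing one assignment per parameter, and the shape check guards against layers that share a name but differ in size.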

Summary

This post only offers one possible approach; it is not necessarily the best, but it works for me so far.
