mxnet.gluon Block.save_parameters()

Define a ResNet-18-style network:

import mxnet as mx
from mxnet import init, nd
from mxnet.gluon import nn
import os


class Residual(nn.Block):  # This class is also provided in the gluonbook package for later reuse.
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def forward(self, X):
        Y = nd.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return nd.relu(Y + X)
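
A quick sanity check of the block (a minimal sketch): with no 1x1 convolution and stride 1, input and output shapes match, so Y + X is well defined.

blk = Residual(3)
blk.initialize()
X = nd.random.uniform(shape=(4, 3, 6, 6))
print(blk(X).shape)  # (4, 3, 6, 6)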

def resnet_block(num_channels, num_residuals, first_block=False):
    blk = nn.Sequential()
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.add(Residual(num_channels, use_1x1conv=True, strides=2))
        else:
            blk.add(Residual(num_channels))
    return blk


net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=11, padding=3, strides=2),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1),
        resnet_block(64, 4, first_block=True),
        resnet_block(128, 4),
        resnet_block(256, 4),
        resnet_block(512, 4),
        nn.GlobalAvgPool2D(),
        nn.Dense(1024),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.Dropout(0.4),
        nn.Dense(10))

Randomly initialize the network and save its parameters with the Block class's built-in save_parameters() member function:

ctx = mx.cpu()  # use mx.gpu() if one is available
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
# Shapes are inferred lazily; one forward pass (assuming 1x96x96 inputs) materializes all parameters
net(nd.random.uniform(shape=(1, 1, 96, 96), ctx=ctx))
new_filename = 'tmp.params'
net.save_parameters(new_filename)
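
The file is written to the current working directory; a quick check (this is what os is imported for):

print(os.path.exists(new_filename))   # True
print(os.path.getsize(new_filename))  # file size in bytes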

Then load the file back and examine the saved structure:

# Load the params back for analysis
params = nd.load('tmp.params')
# params is a dict
print(isinstance(params, dict))
# Print the dict's keys
for key in params:
    print(key)

nd.load returns a dict; printing its keys gives the following (a summarizing sketch follows the list):

  • 0.weight
  • 0.bias
  • 1.gamma
  • 1.beta
  • 1.running_mean
  • 1.running_var
  • 4.0.conv1.weight
  • 4.0.conv1.bias
  • 4.0.conv2.weight
  • 4.0.conv2.bias
  • 4.0.bn1.gamma
  • 4.0.bn1.beta
  • 4.0.bn1.running_mean
  • 4.0.bn1.running_var
  • 4.0.bn2.gamma
  • 4.0.bn2.beta
  • 4.0.bn2.running_mean
  • 4.0.bn2.running_var
  • 4.1.conv1.weight
  • 4.1.conv1.bias
  • 4.1.conv2.weight
  • 4.1.conv2.bias
  • 4.1.bn1.gamma
  • 4.1.bn1.beta
  • 4.1.bn1.running_mean
  • 4.1.bn1.running_var
  • 4.1.bn2.gamma
  • 4.1.bn2.beta
  • 4.1.bn2.running_mean
  • 4.1.bn2.running_var
  • 4.2.conv1.weight
  • 4.2.conv1.bias
  • 4.2.conv2.weight
  • 4.2.conv2.bias
  • 4.2.bn1.gamma
  • 4.2.bn1.beta
  • 4.2.bn1.running_mean
  • 4.2.bn1.running_var
  • 4.2.bn2.gamma
  • 4.2.bn2.beta
  • 4.2.bn2.running_mean
  • 4.2.bn2.running_var
  • 4.3.conv1.weight
  • 4.3.conv1.bias
  • 4.3.conv2.weight
  • 4.3.conv2.bias
  • 4.3.bn1.gamma
  • 4.3.bn1.beta
  • 4.3.bn1.running_mean
  • 4.3.bn1.running_var
  • 4.3.bn2.gamma
  • 4.3.bn2.beta
  • 4.3.bn2.running_mean
  • 4.3.bn2.running_var
  • 5.0.conv1.weight
  • 5.0.conv1.bias
  • 5.0.conv2.weight
  • 5.0.conv2.bias
  • 5.0.conv3.weight
  • 5.0.conv3.bias
  • 5.0.bn1.gamma
  • 5.0.bn1.beta
  • 5.0.bn1.running_mean
  • 5.0.bn1.running_var
  • 5.0.bn2.gamma
  • 5.0.bn2.beta
  • 5.0.bn2.running_mean
  • 5.0.bn2.running_var
  • 5.1.conv1.weight
  • 5.1.conv1.bias
  • 5.1.conv2.weight
  • 5.1.conv2.bias
  • 5.1.bn1.gamma
  • 5.1.bn1.beta
  • 5.1.bn1.running_mean
  • 5.1.bn1.running_var
  • 5.1.bn2.gamma
  • 5.1.bn2.beta
  • 5.1.bn2.running_mean
  • 5.1.bn2.running_var
  • 5.2.conv1.weight
  • 5.2.conv1.bias
  • 5.2.conv2.weight
  • 5.2.conv2.bias
  • 5.2.bn1.gamma
  • 5.2.bn1.beta
  • 5.2.bn1.running_mean
  • 5.2.bn1.running_var
  • 5.2.bn2.gamma
  • 5.2.bn2.beta
  • 5.2.bn2.running_mean
  • 5.2.bn2.running_var
  • 5.3.conv1.weight
  • 5.3.conv1.bias
  • 5.3.conv2.weight
  • 5.3.conv2.bias
  • 5.3.bn1.gamma
  • 5.3.bn1.beta
  • 5.3.bn1.running_mean
  • 5.3.bn1.running_var
  • 5.3.bn2.gamma
  • 5.3.bn2.beta
  • 5.3.bn2.running_mean
  • 5.3.bn2.running_var
  • 6.0.conv1.weight
  • 6.0.conv1.bias
  • 6.0.conv2.weight
  • 6.0.conv2.bias
  • 6.0.conv3.weight
  • 6.0.conv3.bias
  • 6.0.bn1.gamma
  • 6.0.bn1.beta
  • 6.0.bn1.running_mean
  • 6.0.bn1.running_var
  • 6.0.bn2.gamma
  • 6.0.bn2.beta
  • 6.0.bn2.running_mean
  • 6.0.bn2.running_var
  • 6.1.conv1.weight
  • 6.1.conv1.bias
  • 6.1.conv2.weight
  • 6.1.conv2.bias
  • 6.1.bn1.gamma
  • 6.1.bn1.beta
  • 6.1.bn1.running_mean
  • 6.1.bn1.running_var
  • 6.1.bn2.gamma
  • 6.1.bn2.beta
  • 6.1.bn2.running_mean
  • 6.1.bn2.running_var
  • 6.2.conv1.weight
  • 6.2.conv1.bias
  • 6.2.conv2.weight
  • 6.2.conv2.bias
  • 6.2.bn1.gamma
  • 6.2.bn1.beta
  • 6.2.bn1.running_mean
  • 6.2.bn1.running_var
  • 6.2.bn2.gamma
  • 6.2.bn2.beta
  • 6.2.bn2.running_mean
  • 6.2.bn2.running_var
  • 6.3.conv1.weight
  • 6.3.conv1.bias
  • 6.3.conv2.weight
  • 6.3.conv2.bias
  • 6.3.bn1.gamma
  • 6.3.bn1.beta
  • 6.3.bn1.running_mean
  • 6.3.bn1.running_var
  • 6.3.bn2.gamma
  • 6.3.bn2.beta
  • 6.3.bn2.running_mean
  • 6.3.bn2.running_var
  • 7.0.conv1.weight
  • 7.0.conv1.bias
  • 7.0.conv2.weight
  • 7.0.conv2.bias
  • 7.0.conv3.weight
  • 7.0.conv3.bias
  • 7.0.bn1.gamma
  • 7.0.bn1.beta
  • 7.0.bn1.running_mean
  • 7.0.bn1.running_var
  • 7.0.bn2.gamma
  • 7.0.bn2.beta
  • 7.0.bn2.running_mean
  • 7.0.bn2.running_var
  • 7.1.conv1.weight
  • 7.1.conv1.bias
  • 7.1.conv2.weight
  • 7.1.conv2.bias
  • 7.1.bn1.gamma
  • 7.1.bn1.beta
  • 7.1.bn1.running_mean
  • 7.1.bn1.running_var
  • 7.1.bn2.gamma
  • 7.1.bn2.beta
  • 7.1.bn2.running_mean
  • 7.1.bn2.running_var
  • 7.2.conv1.weight
  • 7.2.conv1.bias
  • 7.2.conv2.weight
  • 7.2.conv2.bias
  • 7.2.bn1.gamma
  • 7.2.bn1.beta
  • 7.2.bn1.running_mean
  • 7.2.bn1.running_var
  • 7.2.bn2.gamma
  • 7.2.bn2.beta
  • 7.2.bn2.running_mean
  • 7.2.bn2.running_var
  • 7.3.conv1.weight
  • 7.3.conv1.bias
  • 7.3.conv2.weight
  • 7.3.conv2.bias
  • 7.3.bn1.gamma
  • 7.3.bn1.beta
  • 7.3.bn1.running_mean
  • 7.3.bn1.running_var
  • 7.3.bn2.gamma
  • 7.3.bn2.beta
  • 7.3.bn2.running_mean
  • 7.3.bn2.running_var
  • 9.weight
  • 9.bias
  • 10.gamma
  • 10.beta
  • 10.running_mean
  • 10.running_var
  • 13.weight
  • 13.bias
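
The list is long; a quick way to summarize it is to count keys per top-level index (a small sketch using the params dict loaded above):

from collections import defaultdict

by_layer = defaultdict(list)
for key in params:
    by_layer[key.split('.')[0]].append(key)
for idx in sorted(by_layer, key=int):
    print(idx, len(by_layer[idx]))

Note that indices 2, 3, 8, 11 and 12 never appear: the Activation, MaxPool2D, GlobalAvgPool2D and Dropout layers hold no parameters.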

From these keys we can see how each key is composed (the sketch after this list verifies the mapping):

  • If a layer sits inside a Sequential, its segment is the layer's index within that Sequential.
  • If it is an instance of a custom Block subclass, the attribute name defined in that class is used.
  • The last segment is the parameter's key in the block's _reg_params dict member.
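
For example, '4.0.conv1.weight' means: layer 4 of the top-level Sequential (the first resnet_block), residual 0 inside it, its conv1 attribute, and that layer's weight. A minimal check, assuming the net and params objects defined above:

# Hedged sketch: the saved entry should equal the live parameter
w_saved = params['4.0.conv1.weight']
w_live = net[4][0].conv1.weight.data()
print((w_saved == w_live).min())  # all-ones (1.0) if they match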

_reg_params stores a layer's parameters as a dict mapping name to Parameter object. For a Conv2D layer, for example, _reg_params has the form:

conv2d._reg_params = {'weight': Parameter, 'bias': Parameter}
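
This is easy to inspect on the net above (a sketch; the weight shape assumes the 1-channel input used in the forward pass earlier):

layer = net[0]  # the first Conv2D
for name, p in layer._reg_params.items():
    print(name, type(p).__name__, p.shape)
# weight Parameter (64, 1, 11, 11)
# bias Parameter (64,)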

 

Block.save_parameters()

Now let's look at how Block.save_parameters() turns a block's parameters into the structure shown above.

# Member function of Block: recursively collects all of the block's parameters.
# A block may be nested to arbitrary depth, so this is a DFS over the block tree.

def _collect_params_with_prefix(self, prefix=''):
    if prefix:
        prefix += '.'
    # Add this block's own parameters
    ret = {prefix + key: val for key, val in self._reg_params.items()}
    # Add the parameters of every child block
    for name, child in self._children.items():
        # Recurse, passing down the current prefix plus this child's key.
        # For a Sequential, the key is the child's position: 0, 1, 2, 3, ...
        # For a custom block, it is the attribute name defined in the class.
        # dict.update merges the child's flat dict into this one.
        ret.update(child._collect_params_with_prefix(prefix + name))
    return ret

def save_parameters(self, filename):
    # Build the flat dict listing every parameter
    params = self._collect_params_with_prefix()
    # _reduce() turns each Parameter into a single NDArray on the CPU,
    # aggregating across devices when a parameter lives on several
    arg_dict = {key: val._reduce() for key, val in params.items()}
    # ndarray.save serializes a dict of NDArrays to the file
    ndarray.save(filename, arg_dict)
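
Two quick consistency checks tie this back to the file we saved (a sketch using the net and tmp.params from above):

# Calling the collector directly reproduces the key scheme seen earlier
flat = net._collect_params_with_prefix()
print(list(flat)[:4])  # ['0.weight', '0.bias', '1.gamma', '1.beta']

# Re-doing save_parameters by hand yields exactly the keys stored in the file
manual = {key: val._reduce() for key, val in flat.items()}
print(set(manual) == set(nd.load('tmp.params')))  # True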

 
