定義一個 ResNet 網絡(注:下面每個 stage 堆疊了 4 個殘差塊,比標準 ResNet-18 的每個 stage 2 個殘差塊要深)
import gluonbook as gb
from mxnet.gluon import Trainer,data as gdata, nn
from mxnet import init, nd
import os
import sys
class Residual(nn.Block): # 本類已保存在 gluonbook 包中方便以後使用。
def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
super(Residual, self).__init__(**kwargs)
self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
strides=strides)
self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
if use_1x1conv:
self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
strides=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm()
self.bn2 = nn.BatchNorm()
def forward(self, X):
Y = nd.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3:
X = self.conv3(X)
return nd.relu(Y + X)
def resnet_block(num_channels, num_residuals, first_block =False):
blk = nn.Sequential()
for i in range(num_residuals):
if i==0 and not first_block:
blk.add(Residual(num_channels,use_1x1conv=True,strides=2))
else:
blk.add(Residual(num_channels))
return blk
net = nn.Sequential()
net.add(nn.Conv2D(64,kernel_size=11,padding=3,strides=2),
nn.BatchNorm(),
nn.Activation('relu'),
nn.MaxPool2D(pool_size=3,strides=2,padding=1),
resnet_block(64,4,first_block=True),
resnet_block(128,4),
resnet_block(256,4),
resnet_block(512,4),
nn.GlobalAvgPool2D(),
nn.Dense(1024),
nn.BatchNorm(),
nn.Activation('relu'),
nn.Dropout(0.4),
nn.Dense(10))
隨機初始化網絡並且保存參數,使用 Block 類自帶的 save_parameters() 成員函數:
net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())
new_filename = 'tmp.params'
net.save_parameters(new_filename)
然後就是讀進來數據,分析保存的結構
#load params for analyzation
params = nd.load('tmp.params')
#params is a dict
print(isinstance(params,dict))
#print dict members'names
for key in params:
print(key)
nd.load的結果是一個字典,字典的keys的打印結果如下:
- 0.weight
- 0.bias
- 1.gamma
- 1.beta
- 1.running_mean
- 1.running_var
- 4.0.conv1.weight
- 4.0.conv1.bias
- 4.0.conv2.weight
- 4.0.conv2.bias
- 4.0.bn1.gamma
- 4.0.bn1.beta
- 4.0.bn1.running_mean
- 4.0.bn1.running_var
- 4.0.bn2.gamma
- 4.0.bn2.beta
- 4.0.bn2.running_mean
- 4.0.bn2.running_var
- 4.1.conv1.weight
- 4.1.conv1.bias
- 4.1.conv2.weight
- 4.1.conv2.bias
- 4.1.bn1.gamma
- 4.1.bn1.beta
- 4.1.bn1.running_mean
- 4.1.bn1.running_var
- 4.1.bn2.gamma
- 4.1.bn2.beta
- 4.1.bn2.running_mean
- 4.1.bn2.running_var
- 4.2.conv1.weight
- 4.2.conv1.bias
- 4.2.conv2.weight
- 4.2.conv2.bias
- 4.2.bn1.gamma
- 4.2.bn1.beta
- 4.2.bn1.running_mean
- 4.2.bn1.running_var
- 4.2.bn2.gamma
- 4.2.bn2.beta
- 4.2.bn2.running_mean
- 4.2.bn2.running_var
- 4.3.conv1.weight
- 4.3.conv1.bias
- 4.3.conv2.weight
- 4.3.conv2.bias
- 4.3.bn1.gamma
- 4.3.bn1.beta
- 4.3.bn1.running_mean
- 4.3.bn1.running_var
- 4.3.bn2.gamma
- 4.3.bn2.beta
- 4.3.bn2.running_mean
- 4.3.bn2.running_var
- 5.0.conv1.weight
- 5.0.conv1.bias
- 5.0.conv2.weight
- 5.0.conv2.bias
- 5.0.conv3.weight
- 5.0.conv3.bias
- 5.0.bn1.gamma
- 5.0.bn1.beta
- 5.0.bn1.running_mean
- 5.0.bn1.running_var
- 5.0.bn2.gamma
- 5.0.bn2.beta
- 5.0.bn2.running_mean
- 5.0.bn2.running_var
- 5.1.conv1.weight
- 5.1.conv1.bias
- 5.1.conv2.weight
- 5.1.conv2.bias
- 5.1.bn1.gamma
- 5.1.bn1.beta
- 5.1.bn1.running_mean
- 5.1.bn1.running_var
- 5.1.bn2.gamma
- 5.1.bn2.beta
- 5.1.bn2.running_mean
- 5.1.bn2.running_var
- 5.2.conv1.weight
- 5.2.conv1.bias
- 5.2.conv2.weight
- 5.2.conv2.bias
- 5.2.bn1.gamma
- 5.2.bn1.beta
- 5.2.bn1.running_mean
- 5.2.bn1.running_var
- 5.2.bn2.gamma
- 5.2.bn2.beta
- 5.2.bn2.running_mean
- 5.2.bn2.running_var
- 5.3.conv1.weight
- 5.3.conv1.bias
- 5.3.conv2.weight
- 5.3.conv2.bias
- 5.3.bn1.gamma
- 5.3.bn1.beta
- 5.3.bn1.running_mean
- 5.3.bn1.running_var
- 5.3.bn2.gamma
- 5.3.bn2.beta
- 5.3.bn2.running_mean
- 5.3.bn2.running_var
- 6.0.conv1.weight
- 6.0.conv1.bias
- 6.0.conv2.weight
- 6.0.conv2.bias
- 6.0.conv3.weight
- 6.0.conv3.bias
- 6.0.bn1.gamma
- 6.0.bn1.beta
- 6.0.bn1.running_mean
- 6.0.bn1.running_var
- 6.0.bn2.gamma
- 6.0.bn2.beta
- 6.0.bn2.running_mean
- 6.0.bn2.running_var
- 6.1.conv1.weight
- 6.1.conv1.bias
- 6.1.conv2.weight
- 6.1.conv2.bias
- 6.1.bn1.gamma
- 6.1.bn1.beta
- 6.1.bn1.running_mean
- 6.1.bn1.running_var
- 6.1.bn2.gamma
- 6.1.bn2.beta
- 6.1.bn2.running_mean
- 6.1.bn2.running_var
- 6.2.conv1.weight
- 6.2.conv1.bias
- 6.2.conv2.weight
- 6.2.conv2.bias
- 6.2.bn1.gamma
- 6.2.bn1.beta
- 6.2.bn1.running_mean
- 6.2.bn1.running_var
- 6.2.bn2.gamma
- 6.2.bn2.beta
- 6.2.bn2.running_mean
- 6.2.bn2.running_var
- 6.3.conv1.weight
- 6.3.conv1.bias
- 6.3.conv2.weight
- 6.3.conv2.bias
- 6.3.bn1.gamma
- 6.3.bn1.beta
- 6.3.bn1.running_mean
- 6.3.bn1.running_var
- 6.3.bn2.gamma
- 6.3.bn2.beta
- 6.3.bn2.running_mean
- 6.3.bn2.running_var
- 7.0.conv1.weight
- 7.0.conv1.bias
- 7.0.conv2.weight
- 7.0.conv2.bias
- 7.0.conv3.weight
- 7.0.conv3.bias
- 7.0.bn1.gamma
- 7.0.bn1.beta
- 7.0.bn1.running_mean
- 7.0.bn1.running_var
- 7.0.bn2.gamma
- 7.0.bn2.beta
- 7.0.bn2.running_mean
- 7.0.bn2.running_var
- 7.1.conv1.weight
- 7.1.conv1.bias
- 7.1.conv2.weight
- 7.1.conv2.bias
- 7.1.bn1.gamma
- 7.1.bn1.beta
- 7.1.bn1.running_mean
- 7.1.bn1.running_var
- 7.1.bn2.gamma
- 7.1.bn2.beta
- 7.1.bn2.running_mean
- 7.1.bn2.running_var
- 7.2.conv1.weight
- 7.2.conv1.bias
- 7.2.conv2.weight
- 7.2.conv2.bias
- 7.2.bn1.gamma
- 7.2.bn1.beta
- 7.2.bn1.running_mean
- 7.2.bn1.running_var
- 7.2.bn2.gamma
- 7.2.bn2.beta
- 7.2.bn2.running_mean
- 7.2.bn2.running_var
- 7.3.conv1.weight
- 7.3.conv1.bias
- 7.3.conv2.weight
- 7.3.conv2.bias
- 7.3.bn1.gamma
- 7.3.bn1.beta
- 7.3.bn1.running_mean
- 7.3.bn1.running_var
- 7.3.bn2.gamma
- 7.3.bn2.beta
- 7.3.bn2.running_mean
- 7.3.bn2.running_var
- 9.weight
- 9.bias
- 10.gamma
- 10.beta
- 10.running_mean
- 10.running_var
- 13.weight
- 13.bias
從字典的 key 可以看出,每個 key 的組成規則如下:
- 如果是sequence,則用sequence序號代表該layer
- 如果是自定義的繼承block類的對象,則使用自己定義的layer的名字
- 最後一個元素是block類中_reg_params這個字典成員變量中,參數的key,
_reg_params中以字典的形式保存layer對應的參數,如conv2d的_reg_params爲:
conv2d._reg_params={'weight':NDArray,'bias':NDArray}
block.save_parameters()
下面看block.save_parameters()這個函數如何把block對象的參數保存成上面的樣子
#block的成員函數,用遞歸的方式收集block對象所有的參數
#block對象可能是多層定義的,因此這裏使用了基於DFS的搜索方法
def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
#添加該block的參數
ret = {prefix + key : val for key, val in self._reg_params.items()}
#添加該block下的子block的參數
for name, child in self._children.items():
#遞歸,輸入傳前綴,前綴是當前block的輸入前綴+該子block的key
#如果是sequence,key是一個數字,代表該block在sequence中的位置0,1,2,3……
#如果是自定義的block,則是自定義的名稱
#字典的update操作等於拼接兩個字典
ret.update(child._collect_params_with_prefix(prefix + name))
return ret
def save_parameters(self, filename):
#得到列出所有參數的字典
params = self._collect_params_with_prefix()
#這一步應該是轉化成cpu下的NDArray
arg_dict = {key : val._reduce() for key, val in params.items()}
#ndarray類的保存字典函數
ndarray.save(filename, arg_dict)