MXNet study notes (5): model parameters

Selecting model parameters

class mxnet.gluon.Block(prefix = None, params = None)

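collect_params() is a member function of this class, and this class is the base class of all networks and models. It returns a ParameterDict containing the parameters of this Block and of all its children. You can also ask for only some of the parameters by passing a regular expression. E.g.: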

model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
model.collect_params('.*weight|.*bias')
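In the MXNet 1.x source, the select string seems to be compiled with Python's re module. A parameter is kept when pattern.match(name) succeeds, so the pattern is anchored at the start of the full parameter name, Block prefix included. The pure-Python sketch below mimics that filtering without importing MXNet. The helper select_params and the parameter names are hypothetical and chosen only for illustration.

```python
import re


def select_params(names, select=None):
    """Mimic the name filtering done by Block.collect_params(select).

    Assumption: selection uses re.match, i.e. the pattern must match at
    the *beginning* of the full parameter name (as in MXNet 1.x).
    """
    if not select:
        return list(names)
    pattern = re.compile(select)
    return [n for n in names if pattern.match(n)]


# Hypothetical parameter names; real names carry the Block's prefix.
names = ["conv1_weight", "conv1_bias", "fc_weight", "fc_bias",
         "sequential0_conv1_weight", "batchnorm0_gamma"]

print(select_params(names, "conv1_weight|conv1_bias|fc_weight|fc_bias"))
# ['conv1_weight', 'conv1_bias', 'fc_weight', 'fc_bias']
print(select_params(names, ".*weight|.*bias"))
# ['conv1_weight', 'conv1_bias', 'fc_weight', 'fc_bias', 'sequential0_conv1_weight']
```

Because of the start-of-name anchoring, the first pattern misses sequential0_conv1_weight. The .*-prefixed form is the safer choice when you do not know the Block prefix.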

Initializing model parameters

initialize(init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False)
Initializes the parameters of the Block and of all its children.

Parameters:

  • init (Initializer) – Global default Initializer to be used when Parameter.init() is None. Otherwise, Parameter.init() takes precedence.
  • ctx (Context or list of Context) – Keeps a copy of Parameters on one or many context(s).
  • verbose (bool, default False) – Whether to verbosely print out details on initialization.
  • force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.

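E.g.:

from mxnet import init, gpu
net.collect_params().initialize(init=init.Xavier(), ctx=gpu(0))  # Xavier for every parameter without its own init, stored on GPU 0
fc_layer.initialize()  # fc_layer is any other Block; default initializer, default context (CPU)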
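The precedence rule in the parameter list can be modelled without MXNet. The sketch below uses a hypothetical FakeParameter class, with strings standing in for real Initializer objects. It reproduces only two things: which initializer wins, and when force_reinit matters. Deferred initialization and contexts are ignored.

```python
class FakeParameter:
    """Hypothetical stand-in for a Gluon Parameter; not the MXNet class."""

    def __init__(self, name, init=None):
        self.name = name
        self.init = init              # per-parameter initializer, or None
        self.initialized_with = None  # which initializer was applied last


def initialize_all(params, init="Uniform", force_reinit=False):
    """Model which initializer Block.initialize ends up applying (sketch only)."""
    for p in params:
        if p.initialized_with is not None and not force_reinit:
            # Real Gluon emits a warning here and keeps the existing values.
            continue
        # Parameter.init, when set, takes precedence over the global init.
        p.initialized_with = p.init if p.init is not None else init
    return {p.name: p.initialized_with for p in params}


params = [FakeParameter("dense0_weight"), FakeParameter("dense0_bias", init="Zero")]
print(initialize_all(params, init="Xavier"))
# {'dense0_weight': 'Xavier', 'dense0_bias': 'Zero'}
print(initialize_all(params, init="Normal"))  # already initialized: nothing changes
# {'dense0_weight': 'Xavier', 'dense0_bias': 'Zero'}
print(initialize_all(params, init="Normal", force_reinit=True))
# {'dense0_weight': 'Normal', 'dense0_bias': 'Zero'}
```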

Saving model parameters

file_name = 'net.params'
net.save_parameters(file_name)

Note: Block.collect_params().save() is not a recommended way to save parameters of a Gluon network if you plan to load the parameters back into a Gluon network using Block.load_parameters().

Loading model parameters

load_parameters(filename, ctx=None, allow_missing=False, ignore_extra=False)

Load parameters from file previously saved by save_parameters.

Parameters:

  • filename (str) – Path to parameter file.
  • ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.
  • allow_missing (bool, default False) – Whether to silently skip loading parameters not present in the file.
  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.
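# build_lenet() and ctx are defined in the Gluon tutorial referenced below
new_net = build_lenet(gluon.nn.Sequential())  # rebuild the same architecture first
new_net.load_parameters(file_name, ctx=ctx)   # then load the saved weights into it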
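The two flags allow_missing and ignore_extra come down to set comparisons between the parameter names in the Block and the names stored in the file. The pure-Python sketch below shows that logic using a hypothetical check_load helper, with names like 0.weight chosen for illustration. It raises ValueError in both failure cases, which may not match the exact exception types MXNet uses.

```python
def check_load(block_names, file_names, allow_missing=False, ignore_extra=False):
    """Sketch of the name checks behind load_parameters (hypothetical helper)."""
    block_names, file_names = set(block_names), set(file_names)
    missing = sorted(block_names - file_names)  # in the Block, not in the file
    extra = sorted(file_names - block_names)    # in the file, not in the Block
    if missing and not allow_missing:
        raise ValueError("missing in file: %s (set allow_missing=True)" % missing)
    if extra and not ignore_extra:
        raise ValueError("not present in Block: %s (set ignore_extra=True)" % extra)
    # Only parameters present on both sides receive values from the file.
    return sorted(block_names & file_names)


net_names = ["0.weight", "0.bias", "1.weight", "1.bias"]
file_names = ["0.weight", "0.bias", "1.weight"]
print(check_load(net_names, file_names, allow_missing=True))
# ['0.bias', '0.weight', '1.weight']  (1.bias would still need initializing)
```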

Note that to do this, we need the definition of the network as Python code. If we want to recreate this network on a different machine using the saved weights, we need the same Python code (build_lenet) that created the network to create the new_net object shown above. This means Python code needs to be copied over to any machine where we want to run this network.

Reference for saving and loading Gluon models:
https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html

Saving model parameters and network structure

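The saving and loading approach above has a limitation. Before the parameters can be loaded correctly, you need a copy of the code that defines the network structure, and you must use it to rebuild the network. This is inconvenient.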

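A hybrid model, by contrast, can be serialized with the export function: the network structure is written to a JSON file and the parameters to a .params file. Once saved, the model can be used from other language bindings or in other environments without the original Python code.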

Create a hybrid network:

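net = build_lenet(gluon.nn.HybridSequential())  # HybridSequential instead of Sequential
net.hybridize()    # switch to a cached symbolic graph so it can be exported
train_model(net)   # export() needs at least one forward pass after hybridize(); training provides it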

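After training, the export function saves both the model and its parameters. The network structure goes into a .json file and the parameters into a .params file:

net.export('lenet', epoch=1)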

In this case, export creates lenet-symbol.json and lenet-0001.params in the current directory.
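The file names follow a fixed pattern. The symbol file is <path>-symbol.json, and the parameter file records the epoch as a zero-padded four-digit number. The small helper below makes the convention explicit. It is a hypothetical illustration, not part of Gluon.

```python
def export_filenames(path, epoch=0):
    """Return the (symbol file, params file) names written by HybridBlock.export.

    Mirrors the naming convention '<path>-symbol.json' and '<path>-<epoch:04d>.params'.
    """
    return "%s-symbol.json" % path, "%s-%04d.params" % (path, epoch)


print(export_filenames("lenet", epoch=1))
# ('lenet-symbol.json', 'lenet-0001.params')
```

These are the two files that SymbolBlock.imports takes when the model is loaded back.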

Loading model parameters and structure

Once saved, the model's structure and parameters can be loaded with the imports function. E.g.:

deserialized_net = gluon.nn.SymbolBlock.imports("lenet-symbol.json", ['data'], "lenet-0001.params", ctx=ctx)