MXNet Study Notes (5): Model Parameters

Selecting model parameters

class mxnet.gluon.Block(prefix = None, params = None)

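collect_params() is a member function of this class, which is the base class for all networks and models. It returns a parameter dictionary containing the parameters of this Block and all of its children. It can also return only a subset of the parameters, selected with a regular expression. eg: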

model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
model.collect_params('.*weight|.*bias')
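To see which names a selection pattern would pick up, the matching can be sketched with Python's re module alone. This is a minimal sketch, assuming the pattern is applied with re.match semantics (anchored at the start of each name, not at the end). The parameter names and the select_names helper are hypothetical, and no MXNet call is made.

```python
import re

# Hypothetical parameter names, for illustration only.
param_names = [
    "conv1_weight", "conv1_bias",
    "bn1_gamma", "bn1_beta",
    "fc_weight", "fc_bias",
]

def select_names(names, pattern):
    """Return the names matched by the pattern, using re.match semantics."""
    compiled = re.compile(pattern)
    return [name for name in names if compiled.match(name)]

# Explicit alternation of full names.
explicit = select_names(param_names, "conv1_weight|conv1_bias|fc_weight|fc_bias")
# Wildcard prefix: any name containing "weight" or "bias" after some prefix.
by_suffix = select_names(param_names, ".*weight|.*bias")
# Everything belonging to one (hypothetical) layer.
conv_only = select_names(param_names, "conv1_.*")

print(explicit)
print(by_suffix)
print(conv_only)
```

With these names, both patterns from the example above select the same four weight and bias entries and leave out the batch-norm gamma and beta.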

Initializing model parameters

initialize(init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False)

This function initializes the parameters of a Block and of all its children.

Parameters:

  • init (Initializer) – Global default Initializer to be used when Parameter.init() is None. Otherwise, Parameter.init() takes precedence.
  • ctx (Context or list of Context) – Keeps a copy of Parameters on one or many context(s).
  • verbose (bool, default False) – Whether to verbosely print out details on initialization.
  • force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.

eg:

# init.Xavier() must be an Initializer instance, not the class itself
net.collect_params().initialize(init=init.Xavier(), ctx=gpu(0))
fc_layer.initialize()
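The precedence rule stated for the init argument (a parameter's own initializer wins, the global one is only a fallback) can be sketched in plain Python. This is an illustration only: effective_initializer is a hypothetical helper, and strings stand in for Initializer objects.

```python
def effective_initializer(param_init, global_init):
    """Pick the parameter's own initializer if set, else the global default."""
    return param_init if param_init is not None else global_init

# Hypothetical parameters: name -> initializer set on the parameter itself.
# None means the parameter did not specify one.
param_inits = {"fc_weight": None, "fc_bias": "zeros"}

# "xavier" plays the role of the global `init` argument passed to initialize().
chosen = {
    name: effective_initializer(own, "xavier")
    for name, own in param_inits.items()
}
print(chosen)
```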

Saving model parameters

file_name = 'net.params'
net.save_parameters(file_name)

Note: Block.collect_params().save() is not a recommended way to save parameters of a Gluon network if you plan to load the parameters back into a Gluon network using Block.load_parameters().

Loading model parameters

load_parameters(filename, ctx=None, allow_missing=False, ignore_extra=False)

Load parameters from file previously saved by save_parameters.

Parameters:

  • filename (str) – Path to parameter file.
  • ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.
  • allow_missing (bool, default False) – Whether to silently skip loading parameters not represented in the file.
  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.

eg:

new_net = build_lenet(gluon.nn.Sequential())
new_net.load_parameters(file_name, ctx=ctx)

Note that to do this, we need the definition of the network as Python code. If we want to recreate this network on a different machine using the saved weights, we need the same Python code (build_lenet) that created the network to create the new_net object shown above. This means Python code needs to be copied over to any machine where we want to run this network.
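The effect of the allow_missing and ignore_extra flags described above can be modeled with two sets of names. This is a simplified sketch, not MXNet's actual implementation: check_loadable, the parameter names, and the choice of ValueError are all assumptions made for illustration.

```python
def check_loadable(saved_names, model_names, allow_missing=False, ignore_extra=False):
    """Return the names that would be loaded, or raise on a mismatch.

    Simplified model of the two flags; not MXNet's actual implementation.
    """
    saved, model = set(saved_names), set(model_names)
    missing = sorted(model - saved)  # in the Block, absent from the file
    extra = sorted(saved - model)    # in the file, absent from the Block
    if missing and not allow_missing:
        raise ValueError("missing from file: %s" % missing)
    if extra and not ignore_extra:
        raise ValueError("not present in Block: %s" % extra)
    return sorted(saved & model)

# Hypothetical names: the file and the Block disagree on one layer each.
saved_names = ["0.weight", "0.bias", "1.weight"]
model_names = ["0.weight", "0.bias", "2.weight"]

# Permissive load: only the names common to both are loaded.
loaded = check_loadable(saved_names, model_names,
                        allow_missing=True, ignore_extra=True)

# Strict load (both flags False): the mismatch is reported instead.
try:
    check_loadable(saved_names, model_names)
    strict_error = None
except ValueError as err:
    strict_error = str(err)

print(loaded)
print(strict_error)
```

Leaving both flags at False is the safer default when the file is expected to match the network exactly, since a silent partial load can hide a mismatch between the saved file and the code that rebuilds the network.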

Reference for saving and loading Gluon models:
https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html

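Saving model parameters and network architecture

The save and load approach described so far has one limitation: the code that defines the network architecture must be copied over and used to rebuild the network before the parameters can be loaded correctly. This is inconvenient.

A hybrid model, however, can be serialized with the export function, which writes the architecture to a JSON file. Once saved, the model can be used from other languages or through other interfaces.

Creating a hybrid network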

net = build_lenet(gluon.nn.HybridSequential())
net.hybridize()
train_model(net)

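Once training is complete, we can use the export function to save both the model and its parameters: the architecture is saved as a .json file and the parameters as a .params file.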

net.export('lenet', epoch = 1)

export in this case creates lenet-symbol.json and lenet-0001.params in the current directory.
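The two file names follow from the path prefix and the epoch number passed to export. A minimal sketch of that naming, assuming the epoch is zero-padded to four digits as in the file names shown above (export_file_names is a hypothetical helper, not part of MXNet):

```python
def export_file_names(path, epoch=0):
    """Return (symbol_file, params_file) for a given path prefix and epoch."""
    symbol_file = "%s-symbol.json" % path
    params_file = "%s-%04d.params" % (path, epoch)  # epoch padded to 4 digits
    return symbol_file, params_file

symbol_file, params_file = export_file_names("lenet", epoch=1)
print(symbol_file, params_file)
```

Both names are needed later, since loading the exported model takes the symbol file and the parameter file as separate arguments.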

Loading model parameters and architecture

The architecture and parameters saved above can be loaded back with the imports function. eg:

deserialized_net = gluon.nn.SymbolBlock.imports("lenet-symbol.json", ['data'], "lenet-0001.params", ctx=ctx)