Selecting model parameters
class mxnet.gluon.Block(prefix = None, params = None)
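collect_params() is a member function of this class, which is the base class of all networks and models. It returns a parameter dictionary (a ParameterDict) containing the parameters of this Block and all of its children. A regular expression can be passed to return only a subset of the parameters, e.g.: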
model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
model.collect_params('.*weight|.*bias')
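To see which names a pattern would pick up, the selection can be imitated with Python's re module. This is only a minimal sketch, assuming the pattern is applied with re.match semantics (anchored at the start of each name). The parameter names and the select_names helper are hypothetical; real names depend on the prefixes of your blocks.

```python
import re

# Hypothetical parameter names; real ones depend on each block's prefix.
param_names = [
    "conv1_weight", "conv1_bias",
    "conv2_weight", "conv2_bias",
    "fc_weight", "fc_bias",
    "bn_gamma", "bn_beta",
]

def select_names(names, pattern):
    """Return the names matched by pattern, anchored at the start (re.match)."""
    regex = re.compile(pattern)
    return [name for name in names if regex.match(name)]

# The two patterns used in the examples above.
explicit = select_names(param_names, "conv1_weight|conv1_bias|fc_weight|fc_bias")
wildcard = select_names(param_names, ".*weight|.*bias")

print(explicit)
print(wildcard)
```

Under that assumption, a bare pattern such as 'weight' would select nothing here, because no name starts with "weight"; the leading '.*' is what lets the pattern reach the suffix.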
Initializing model parameters
initialize(init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False)
This function initializes the parameters of the Block and all of its children.
Parameters:
- init (Initializer) – Global default Initializer to be used when Parameter.init() is None. Otherwise, Parameter.init() takes precedence.
- ctx (Context or list of Context) – Keeps a copy of Parameters on one or many context(s).
- verbose (bool, default False) – Whether to verbosely print out details on initialization.
- force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.
eg:
from mxnet import init, gpu
net.collect_params().initialize(init=init.Xavier(), ctx=gpu(0))
fc_layer.initialize()
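The interaction between init and force_reinit described above can be pictured with a small toy model. This is a sketch in plain Python, not MXNet code: plan_initialization, the dictionary layout, and the string stand-ins for initializers are all hypothetical, and the sketch assumes that an already initialized parameter is simply skipped unless force_reinit is set.

```python
def plan_initialization(params, global_init, force_reinit=False):
    """Return {name: initializer} for parameters that would be (re)initialized.

    params maps a name to a dict with the keys "init" (a per-parameter
    initializer, or None) and "initialized" (bool). Toy model only.
    """
    plan = {}
    for name, info in params.items():
        if info["initialized"] and not force_reinit:
            continue  # already initialized: left untouched
        # A per-parameter initializer takes precedence over the global default.
        plan[name] = info["init"] if info["init"] is not None else global_init
    return plan

# Strings stand in for initializer objects (hypothetical values).
params = {
    "fc_weight": {"init": None, "initialized": False},
    "fc_bias": {"init": "zeros", "initialized": False},
    "conv_weight": {"init": None, "initialized": True},
}

first = plan_initialization(params, "xavier")
forced = plan_initialization(params, "xavier", force_reinit=True)
print(first)
print(forced)
```

In this toy model, fc_bias keeps its own initializer, fc_weight falls back to the global one, and conv_weight is only touched when force_reinit=True.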
Saving model parameters
file_name = 'net.params'
net.save_parameters(file_name)
Note: Block.collect_params().save() is not a recommended way to save parameters of a Gluon network if you plan to load the parameters back into a Gluon network using Block.load_parameters().
Loading model parameters
load_parameters(filename, ctx=None, allow_missing=False, ignore_extra=False)
Load parameters from file previously saved by save_parameters.
Parameters:
- filename (str) – Path to parameter file.
- ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.
- allow_missing (bool, default False) – Whether to silently skip loading parameters not represented in the file.
- ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.
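The two flags cover opposite kinds of mismatch between the Block and the file. The following sketch mimics that check on plain sets of names. It is not MXNet's implementation: check_param_names is a hypothetical helper, and raising ValueError is an assumption made for illustration.

```python
def check_param_names(block_names, file_names,
                      allow_missing=False, ignore_extra=False):
    """Return the names that would be loaded, or raise on a mismatch."""
    block_names, file_names = set(block_names), set(file_names)
    missing = block_names - file_names  # in the Block, absent from the file
    extra = file_names - block_names    # in the file, absent from the Block
    if missing and not allow_missing:
        raise ValueError("missing from file: %s" % sorted(missing))
    if extra and not ignore_extra:
        raise ValueError("not present in Block: %s" % sorted(extra))
    return sorted(block_names & file_names)

# Hypothetical names: the Block has fc_bias, the file has an extra old_weight.
block = ["fc_weight", "fc_bias"]
saved = ["fc_weight", "old_weight"]

loaded = check_param_names(block, saved, allow_missing=True, ignore_extra=True)
print(loaded)

try:
    check_param_names(block, saved)  # both flags left at False
    strict_error = None
except ValueError as err:
    strict_error = str(err)
print(strict_error)
```

With both flags left at their default of False, any difference between the two sets of names is reported rather than skipped, which is usually the safer setting when the file is expected to match the network exactly.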
new_net = build_lenet(gluon.nn.Sequential())
new_net.load_parameters(file_name, ctx=ctx)
Note that to do this, we need the definition of the network as Python code. If we want to recreate this network on a different machine using the saved weights, we need the same Python code (build_lenet) that created the network to create the new_net object shown above. This means Python code needs to be copied over to any machine where we want to run this network.
Reference for saving and loading Gluon models:
https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html
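Saving model parameters and network architecture
The approach to saving and loading parameters described so far has one limitation: the code that defines the network architecture must be copied over and used to rebuild the network before the parameters can be loaded correctly, which is inconvenient.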
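A hybrid model, by contrast, can be serialized to a JSON file with the export function. Once saved, the model can be used from other languages or in other frontend environments.
Create a hybrid network: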
net = build_lenet(gluon.nn.HybridSequential())
net.hybridize()
train_model(net)
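Once training is complete, the export function saves both the model and its parameters: the model architecture is written to a .json file and the parameters to a .params file.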
net.export('lenet', epoch = 1)
export in this case creates lenet-symbol.json and lenet-0001.params in the current directory.
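The two file names follow from the path prefix and the epoch number. As a minimal sketch, assuming the naming scheme is "<path>-symbol.json" plus "<path>-<epoch padded to four digits>.params" (which matches the names above), the expected names can be computed ahead of time. export_file_names is a hypothetical helper, not part of MXNet, and its default epoch of 0 is likewise an assumption.

```python
def export_file_names(path, epoch=0):
    """Return the (symbol file, params file) names expected for a prefix and epoch."""
    symbol_file = "%s-symbol.json" % path
    params_file = "%s-%04d.params" % (path, epoch)  # epoch zero-padded to 4 digits
    return symbol_file, params_file

symbol_file, params_file = export_file_names("lenet", epoch=1)
print(symbol_file)
print(params_file)
```

Computing the names this way is handy when a later step needs to locate the exported files without hard-coding them.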
Loading model parameters and architecture
Once saved, the model's architecture and parameters can be loaded with the imports function, e.g.:
deserialized_net = gluon.nn.SymbolBlock.imports("lenet-symbol.json", ['data'], "lenet-0001.params", ctx=ctx)