MXNET學習筆記(一):Module類(1)

Module 是 mxnet 提供給用戶的一個高級封裝的類。有了它,我們可以很容易的來訓練模型。

Module 包含以下單元的一個 wraper

  • symbol : 用來表示網絡前向過程的 symbol
  • optimizer: 優化器,用來更新網絡。
  • exec_group: 用來執行 前向和反向計算。

所以 Module 可以幫助我們做

  • 前向計算,(由 exec_group 提供支持)
  • 反向計算,(由 exec_group 提供支持)
  • 更新網絡,(由 optimizer 提供支持)

一個 Demo

下面來看 MXNET 官網上提供的一個 Module 案例

第一部分:準備數據

import logging
logging.getLogger().setLevel(logging.INFO)
import mxnet as mx
import numpy as np

fname = mx.test_utils.download('http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')
data = np.genfromtxt(fname, delimiter=',')[:,1:]
label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])

batch_size = 32
ntrain = int(data.shape[0]*0.8)
train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)

第二部分:構建網絡

net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
net = mx.sym.Activation(net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(net, name='fc2', num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mx.viz.plot_network(net)

第三部分:創建Module

mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])

# 通過data_shapes 和 label_shapes 推斷其餘參數的 shape,然後給它們分配空間
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# 初始化模型的參數
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# 初始化優化器,優化器用來更新模型
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# use accuracy as the metric
metric = mx.metric.create('acc')
# train 5 epochs, i.e. going over the data iter one pass
for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # 前向計算
        mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
        mod.backward()                          # 反向傳導
        mod.update()                            # 更新參數
    print('Epoch %d, Training %s' % (epoch, metric.get()))

關於 bind 的參數:

  • data_shapes : list of (str, tuple), str 是 數據 Symbol 的名字,tuple是 mini-batch 的形狀,所以一般參數是[('data', (64, 3, 224, 224))]
  • label_shapes: list of (str, tuple),str 是 標籤 Symbol 的名字,tuple是 mini-batch 標籤的形狀,一般 分類任務的 參數爲 [('softmax_label'),(64,)]
  • 爲什麼上面兩個參數都是 list 呢? 因爲可能某些網絡架構,不止一個 數據,不止一種 標籤。

關於 forward的參數

  • data_batch : 一個 mx.io.DataBatch-like 對象。只要一個對象,可以 .data返回 mini-batch 訓練數據, .label 返回相應的標籤,就可以作爲 data_batch 的實參 。
  • 關於 DataBatch對象:.data 返回的是 list of NDArray(網絡可能有多個輸入數據),.label 也一樣。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章