前面寫了一些關於 MXNet API 的介紹,現在給出 GitHub 上 MXNet 的一個例子:
import os
import sys

# NOTE(review): `utils.get_data` is not part of mxnet; presumably it is the
# helper shipped alongside the MXNet examples and must be importable -- confirm.
from utils import get_data
import mxnet as mx
import numpy as np
import logging

# ------------------------------------------------------------------------------
# Build the computation graph: three fully connected layers (128 -> 64 -> 10
# hidden units) with ReLU activations after the first two, then a softmax
# output.
# ------------------------------------------------------------------------------
data = mx.symbol.Variable('data')

fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")

fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")

fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
# print(fc3)  # at this point fc3 is still only a symbol

softmax = mx.symbol.SoftmaxOutput(fc3, name='softmax')

n_epoch = 2
batch_size = 100

# ------------------------------------------------------------------------------
# Load the data: fetch MNIST into "<script dir>/data" and create the training
# and validation iterators over the idx-ubyte files.
# ------------------------------------------------------------------------------
basedir = os.path.dirname(__file__)
_data_dir = os.path.join(basedir, "data")
get_data.get_mnist(_data_dir)

# Keyword arguments shared by both iterators.
_iter_kwargs = dict(
    data_shape=(784,),
    batch_size=batch_size,
    shuffle=True,
    flat=True,
    silent=False,
)

# Only the training iterator is given an explicit seed.
train_dataiter = mx.io.MNISTIter(
    image=os.path.join(_data_dir, "train-images-idx3-ubyte"),
    label=os.path.join(_data_dir, "train-labels-idx1-ubyte"),
    seed=10,
    **_iter_kwargs
)

val_dataiter = mx.io.MNISTIter(
    image=os.path.join(_data_dir, "t10k-images-idx3-ubyte"),
    label=os.path.join(_data_dir, "t10k-labels-idx1-ubyte"),
    **_iter_kwargs
)

################################################################################
# Intermediate-level API
################################################################################
# Create the module from the output symbol.
mod = mx.mod.Module(softmax)

# Bind the module to the data/label shapes advertised by the training iterator.
mod.bind(
    data_shapes=train_dataiter.provide_data,
    label_shapes=train_dataiter.provide_label,
)

# Initialise the model parameters.
mod.init_params()

# Configure the optimizer's learning rate and momentum.
mod.init_optimizer(
    optimizer_params={'learning_rate': 0.01, 'momentum': 0.9}
)

# Evaluation metric used to report training progress.
metric = mx.metric.create('acc')

for i_epoch in range(n_epoch):
    for i_iter, batch in enumerate(train_dataiter):
        # Forward pass on the current batch.
        mod.forward(batch)
        # Fold this batch's outputs and labels into the running metric.
        mod.update_metric(metric, batch.label)
        # Backward pass.
        mod.backward()
        # Apply the optimizer step to the parameters.
        mod.update()

    # Report the metric accumulated over this epoch.
    for name, val in metric.get_name_value():
        print('epoch %03d: %s=%f' % (i_epoch, name, val))

    # Parameters changed during this epoch, so the metric has to be measured
    # afresh next time: clear it, and rewind the iterator for the next pass.
    metric.reset()
    train_dataiter.reset()