mxnet 手寫數字識別的例子

前面寫了一些關於mxnetAPI的,現在給出一個github上mxnet的一個例子

import os,  sys
from utils import get_data

import mxnet as mx
import numpy as np
import logging
# 創建計算圖
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
# print(fc3)  這時候只是一個符號
softmax = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')

n_epoch = 2
batch_size = 100
# 加載數據
basedir = os.path.dirname(__file__)
get_data.get_mnist(os.path.join(basedir, "data"))

train_dataiter = mx.io.MNISTIter(
        image=os.path.join(basedir, "data", "train-images-idx3-ubyte"),
        label=os.path.join(basedir, "data", "train-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)

val_dataiter = mx.io.MNISTIter(
        image=os.path.join(basedir, "data", "t10k-images-idx3-ubyte"),
        label=os.path.join(basedir, "data", "t10k-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)

################################################################################
# Intermediate-level API
################################################################################
#傳入symbol,創建模型
mod = mx.mod.Module(softmax)
# bind
mod.bind(data_shapes=train_dataiter.provide_data, label_shapes=train_dataiter.provide_label)
# 模型初始化參數
mod.init_params()
# 模型的優化器和學習率
mod.init_optimizer(optimizer_params={'learning_rate':0.01, 'momentum': 0.9})
# 建立模型評價的選擇
metric = mx.metric.create('acc')

for i_epoch in range(n_epoch):
    for i_iter, batch in enumerate(train_dataiter):
        # 前向計算
        mod.forward(batch)
        # 更新
        mod.update_metric(metric, batch.label)
        # 後向
        mod.backward()
        # 再更新
        mod.update()

    for name, val in metric.get_name_value():
        print('epoch %03d: %s=%f' % (i_epoch, name, val))
    # 每訓練一次,參數都會更新,所以評價也得更新,所以要先重置
    metric.reset()
    train_dataiter.reset()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章