[轉]MXNet實現手寫數字MNIST數據集訓練、預測指南

環境:

Anaconda3(64-bit),安裝 MXNet 1.3.1;預測腳本另外用到 Pillow(PIL)和 matplotlib;opencv_python-3.4.5.20-cp36-cp36m-win_amd64.whl 為可選項(本文代碼並未使用 OpenCV)。

訓練源碼:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 16:30:15 2019

@author: houwenbin
"""

import numpy as np
import mxnet as mx
import logging

# Lower the root logger's threshold to DEBUG so that progress messages emitted
# through `logging` during training are not filtered out.
logging.root.setLevel(logging.DEBUG)

# Fetch MNIST through MXNet's test utilities and wrap the train/test splits in
# mini-batch iterators. Presumably the images are (N, 1, 28, 28) arrays scaled
# to [0, 1] -- TODO confirm against mx.test_utils.get_mnist.
batch_size = 100
mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(
    data=mnist['train_data'],
    label=mnist['train_label'],
    batch_size=batch_size,
    shuffle=True,  # reshuffle training samples between epochs
)
val_iter = mx.io.NDArrayIter(
    data=mnist['test_data'],
    label=mnist['test_label'],
    batch_size=batch_size,
)

# Network definition: a LeNet-style CNN (two conv/tanh/max-pool stages followed
# by two fully connected layers) ending in a 10-way softmax, one output per digit.
# NOTE(review): no layer is given an explicit `name`, so MXNet generates the
# parameter names itself (presumably from op type + creation order, e.g.
# convolution0_weight). The checkpoint saved at the end of this script is keyed
# by those names, so reordering or inserting ops here will likely make older
# checkpoints fail to load.
data = mx.sym.var('data') 
# first conv layer: 20 filters of 5x5, tanh activation, then 2x2 max pooling
# with stride 2 (halves the spatial height and width)
conv1= mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1= mx.sym.Activation(data=conv1, act_type="tanh")
pool1= mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer: 50 filters of 5x5, tanh, then 2x2 / stride-2 max pooling
conv2= mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2= mx.sym.Activation(data=conv2, act_type="tanh")
pool2= mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fully connected layer: flatten the feature maps, then 500 hidden units
# with tanh activation
flatten= mx.sym.Flatten(data=pool2)
fc1= mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3= mx.sym.Activation(data=fc1, act_type="tanh")
# second fully connected layer: 10 outputs, one score per digit class
fc2= mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax output layer named 'softmax'; during training it also supplies the
# softmax cross-entropy gradient against the label input
lenet= mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# Number of passes over the training set. The same value is used as the
# checkpoint epoch number, so the saved checkpoint always matches the number of
# epochs actually trained (the prediction script loads prefix 'lenet', epoch 10).
num_epoch = 10

# Create a trainable module. It runs on the CPU; pass mx.gpu(0) as the context
# instead to train on the first GPU.
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

# Train with plain SGD, reporting accuracy on the validation iterator and
# logging throughput every 100 batches.
lenet_model.fit(
    train_iter,
    eval_data=val_iter,
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1},
    eval_metric='acc',
    batch_end_callback=mx.callback.Speedometer(batch_size, 100),
    num_epoch=num_epoch,
)

# Save the network symbol and trained parameters under the "lenet" prefix.
# Optimizer state is not needed for inference, so it is not saved.
lenet_model.save_checkpoint("lenet", num_epoch, save_optimizer_states=False)


預測源碼:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 20:17:26 2019

@author: houwenbin
"""

import time
import mxnet as mx
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Checkpoint to load (prefix + epoch, matching what the training script saved)
# and the image file to classify.
prefix = 'lenet'
iteration = 10
img_name = './digit_8.jpg'
# Label for each network output index: output i corresponds to digit i.
synsets = list(range(10))

# MNIST-style preprocessing: turn an image file into a (1, 1, 28, 28) float32
# array with pixel values scaled to [0, 1].
def load_image(img_name, *, show=True):
    """Load *img_name* and preprocess it for the LeNet MNIST model.

    The image is resized to 28x28 and converted to 8-bit grayscale
    (L = 0.299R + 0.587G + 0.114B). Pixel values are then scaled from
    [0, 255] to [0, 1].

    Args:
        img_name: Path of the image file to load.
        show: If True (the default), display the 28x28 grayscale image with
            matplotlib before returning. Depending on the backend,
            plt.show() may block until the window is closed, so pass False
            for unattended use.

    Returns:
        A float32 numpy array of shape (1, 1, 28, 28) (batch, channel,
        height, width), or None if the file cannot be opened or decoded.
    """
    try:
        # `with` releases the underlying file handle once the pixels are read.
        with Image.open(img_name) as src:
            gray = src.resize((28, 28)).convert('L')
    except OSError as err:
        # Image.open never returns None; it raises (missing file, unknown
        # format, truncated data). Report the error and return None so
        # callers can check `img is None`.
        print("Failed to load image {0}: {1}".format(img_name, err))
        return None

    img = np.array(gray, dtype=np.float32)
    print(img.shape)
    if show:
        # Use a gray colormap; matplotlib's default false-color map makes a
        # single-channel image hard to read.
        plt.imshow(img, cmap='gray')
        plt.show()
    # Add batch and channel axes, then scale [0, 255] -> [0, 1].
    return img.reshape(1, 1, 28, 28) / 255


# Measure how long loading the checkpoint and rebuilding the module takes.
# perf_counter is a monotonic, high-resolution clock meant for intervals;
# time.time() is wall-clock time and can be coarse or jump.
time0 = time.perf_counter()

# Load the network symbol plus the argument (weight) and auxiliary parameters
# written by save_checkpoint(prefix, iteration).
sym, arg, aux = mx.model.load_checkpoint(prefix, iteration)
# Rebuild an inference-only module. With label_names=None no label input is
# declared, so no label shapes are passed to bind().
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
# Allocate memory for a single input of shape (1, 1, 28, 28): one grayscale
# 28x28 image.
mod.bind(for_training=False, data_shapes=[('data', (1, 1, 28, 28))],
         label_shapes=None)
# NOTE(review): allow_missing=True is presumably needed because the loaded
# symbol still has the SoftmaxOutput label argument, which is absent from the
# saved parameters. TODO confirm. Any genuinely missing weight would then not
# raise an error here.
mod.set_params(arg, aux, allow_missing=True)

time1 = time.perf_counter()
print("模型加載和重建時間:{0}".format(time1 - time0))
# Load and preprocess the image to classify; abort if it could not be loaded.
img = load_image(img_name)
if img is None:
    # SystemExit is a builtin, whereas exit() is only injected by the `site`
    # module (absent e.g. under `python -S` or in frozen executables). The
    # message is printed to stderr and the process exits with a non-zero status.
    raise SystemExit("failed to load image: {0}".format(img_name))

# Expected to be (1, 1, 28, 28), the shape the module was bound with.
print(img.shape)
# Module.forward only reads a `.data` list of NDArrays from the batch object,
# so a minimal namedtuple is enough for inference (no labels needed). It is
# set up before the timer starts so only the prediction itself is timed.
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

time0 = time.perf_counter()
# Run inference on the (1, 1, 28, 28) image; results are read via get_outputs().
mod.forward(Batch([mx.nd.array(img)]))
# MXNet executes operations asynchronously: forward() can return before the
# computation finishes. Wait for it so the timing covers the actual work.
mx.nd.waitall()
time1 = time.perf_counter()
print("前向預測時間:{0}".format(time1 - time0))

# Print the network output and the top-5 predictions, most confident first.
top_k = 5
outputs = mod.get_outputs()
print(outputs)
# The single output holds the softmax probabilities, shape (1, 10).
prob = outputs[0].asnumpy()
print("-------result-------", prob, prob.shape)
# Drop the batch axis: (1, 10) -> (10,).
prob = np.squeeze(prob)
print("-------squeeze result-------", prob, prob.shape)
print("-------sorted prob--------", np.sort(prob))  # ascending order
print("-------arg sorted prob--------", np.argsort(prob))
# Class indices ordered by descending confidence.
a = np.argsort(prob)[::-1]
print("------top sorted index-------", a, a.shape)
for i in a[:top_k]:
    print('probability=%f, class=%s' % (prob[i], synsets[i]))


數據準備:

使用畫圖工具繪製一張128x128的黑色背景圖片,再用橡皮擦(默認擦成白色)擦出待識別的數字,得到黑底白字的圖片即可(本文為digit_8.jpg)。MNIST樣本同樣是黑底白字;預測腳本會把圖片縮放為28x28並轉為灰度後再送入網絡。

運行結果(注意:圖片文件名為digit_8.jpg,但下面輸出的預測類別為4;若圖中寫的確實是8,說明手繪圖片與MNIST樣本差異較大,需自行核對輸入圖片):

in[1]:runfile('C:/Users/houwenbin/Documents/PythonProject/test_mnist.py', wdir='C:/Users/houwenbin/Documents/PythonProject')

模型加載和重建時間:0.0060160160064697266

(28, 28)

(1, 1, 28, 28)

前向預測時間:0.0010042190551757812

[

[[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01

  4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]]

<NDArray 1x10 @cpu(0)>]

-------result------- [[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01

  4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]] (1, 10)

-------squeeze result------- [3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01

 4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05] (10,)

-------sorted prob-------- [4.0004899e-10 1.1044953e-08 1.3175709e-06 2.3288235e-06 3.0556594e-06

 4.1811345e-06 1.8727254e-05 3.0342795e-05 3.0720061e-05 9.9990916e-01]

-------arg sorted prob-------- [5 3 1 8 0 2 7 6 9 4]

------top sorted index------- [4 9 6 7 2 0 8 1 3 5] (10,)

probability=0.999909, class=4

probability=0.000031, class=9

probability=0.000030, class=6

probability=0.000019, class=7

probability=0.000004, class=2

in[2]:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章