[转] 使用 MXNet 实现 MNIST 手写数字数据集的训练与预测指南

环境:

Anaconda3(64-bit),安装 mxnet 1.3.1;opencv_python-3.4.5.20-cp36-cp36m-win_amd64.whl 为可选依赖。预测脚本还会用到 Pillow 与 matplotlib(Anaconda 已自带)。

训练源码:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 16:30:15 2019

@author: houwenbin
"""

import numpy as np
import mxnet as mx
import logging

# Lower the root logger threshold so the progress messages that Module.fit and
# the Speedometer callback emit through `logging` are shown.
logging.getLogger().setLevel(logging.DEBUG)

batch_size = 100
# Number of training epochs. The same number is used as the checkpoint epoch
# below, so the prediction script must load this value (its `iteration`).
num_epoch = 10

# MNIST via MXNet's test utilities; the dict keys used below select the
# train/test images and labels.
mnist = mx.test_utils.get_mnist()
# Shuffle only the training data; validation order does not matter.
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
                               batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

# LeNet-style network. Layers are created in the same order as before so the
# auto-generated parameter names stored in the checkpoint are unchanged.
data = mx.sym.var('data')
# first conv layer: 20 5x5 filters -> tanh -> 2x2 max-pool
conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2, 2), stride=(2, 2))
# second conv layer: 50 5x5 filters -> tanh -> 2x2 max-pool
conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2, 2), stride=(2, 2))
# first fully-connected layer: 500 hidden units -> tanh
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fully-connected layer: one output per digit class
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss; the name 'softmax' determines the label input name
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# Create a trainable module on the CPU (pass mx.gpu(0) instead to use GPU 0).
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

# Train with plain SGD, reporting accuracy on the validation set each epoch and
# throughput every 100 batches.
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=num_epoch)

# Save the network symbol and trained parameters under the 'lenet' prefix for
# epoch `num_epoch` (optimizer state is not saved); the prediction script
# reloads them with mx.model.load_checkpoint('lenet', num_epoch).
lenet_model.save_checkpoint("lenet", num_epoch, False)


预测源码:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 20:17:26 2019

@author: houwenbin
"""

import time
import mxnet as mx
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Checkpoint to load: the training script saves it with
# save_checkpoint("lenet", 10, ...), i.e. prefix 'lenet' at epoch 10.
prefix = 'lenet'
iteration = 10
# Path of the image to classify.
img_name = './digit_8.jpg'
# Class label for each network output position: the digits 0 through 9.
synsets = list(range(10))

# imagenet 图像预处理
def load_image(img_name):
        
    #PIL
    #相关:scipy.misc.imread, scipy.ndimage.imread
    #misc.imread 提供可选参数mode,但本质上是调用PIL,具体的模式可以去看srccode或者document
    img = Image.open(img_name)
    if img is None:
        return None
    
    img = img.resize((28,28))
    img = np.array(img.convert('L'),'f') #读取图片,灰度化,转换为数组,L = 0.299R + 0.587G + 0.114B。'f'为float类型
    #统一使用plt进行显示,不管是plt还是cv2.imshow,在python中只认numpy.array,但是由于cv2.imread 的图片是BGR,cv2.imshow 时相应的换通道显示
    print(img.shape)
    plt.imshow(img)
    plt.show()
    #
    img = img.reshape(1,1,28,28).astype(np.float32)/255
    return img


# --- Rebuild the trained network for inference ---------------------------------
time0 = time.perf_counter()  # monotonic clock, suited to measuring durations

# Load the symbol (network graph) plus argument and auxiliary parameters.
sym, arg, aux = mx.model.load_checkpoint(prefix, iteration)
# Rebuild the module; no labels are fed at inference time.
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
# Allocate buffers for a single 1x1x28x28 input; inference needs no label
# shapes, so none are passed.
mod.bind(for_training=False, data_shapes=[('data', (1, 1, 28, 28))])
# Load the trained weights. allow_missing=True tolerates symbol arguments that
# have no value in the checkpoint.
# NOTE(review): presumably needed for the unused 'softmax_label' input when
# label_names=None; confirm before switching to allow_missing=False.
mod.set_params(arg, aux, allow_missing=True)

time1 = time.perf_counter()
print("模型加载和重建时间:{0}".format(time1 - time0))

# --- Load and preprocess the input image ---------------------------------------
img = load_image(img_name)
if img is None:
    # Exit with a non-zero status so callers can tell the run failed.
    raise SystemExit(1)

print(img.shape)

# --- Forward pass ----------------------------------------------------------------
# Module.forward is given an object whose `.data` is a list of input NDArrays.
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

time0 = time.perf_counter()
# Simple inference on the 1x1x28x28 input batch.
mod.forward(Batch([mx.nd.array(img)]))
time1 = time.perf_counter()
print("前向预测时间:{0}".format(time1 - time0))

# --- Report the top-5 predictions ------------------------------------------------
outputs = mod.get_outputs()
print(outputs)
prob = outputs[0].asnumpy()  # class probabilities, shape (1, 10)
print("-------result-------", prob, prob.shape)
prob = np.squeeze(prob)  # drop the batch axis -> (10,)
print("-------squeeze result-------", prob, prob.shape)
print("-------sorted prob--------", np.sort(prob))  # ascending order
print("-------arg sorted prob--------", np.argsort(prob))
# Class indices ordered from most to least probable.
a = np.argsort(prob)[::-1]
print("------top sorted index-------", a, a.shape)
for i in a[:5]:
    print('probability=%f, class=%s' % (prob[i], synsets[i]))


数据准备:

使用画图工具新建一张 128x128 的黑色背景图片,再用橡皮擦(它涂上的是白色背景色)在上面“写”出待检测的数字,得到与 MNIST 一致的白字黑底图像,保存即可(本文为 digit_8.jpg)。

运行结果:

in[1]:runfile('C:/Users/houwenbin/Documents/PythonProject/test_mnist.py', wdir='C:/Users/houwenbin/Documents/PythonProject')

模型加载和重建时间:0.0060160160064697266

(28, 28)

(1, 1, 28, 28)

前向预测时间:0.0010042190551757812

[

[[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01

  4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]]

<NDArray 1x10 @cpu(0)>]

-------result------- [[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01

  4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]] (1, 10)

-------squeeze result------- [3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01

 4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05] (10,)

-------sorted prob-------- [4.0004899e-10 1.1044953e-08 1.3175709e-06 2.3288235e-06 3.0556594e-06

 4.1811345e-06 1.8727254e-05 3.0342795e-05 3.0720061e-05 9.9990916e-01]

-------arg sorted prob-------- [5 3 1 8 0 2 7 6 9 4]

------top sorted index------- [4 9 6 7 2 0 8 1 3 5] (10,)

probability=0.999909, class=4

probability=0.000031, class=9

probability=0.000030, class=6

probability=0.000019, class=7

probability=0.000004, class=2

in[2]:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章