環境:
Anaconda3(64-bit,Python 3.6),安裝 mxnet 1.3.1;opencv_python-3.4.5.20-cp36-cp36m-win_amd64.whl 為可選安裝(本文代碼未使用 OpenCV,預測腳本使用 Pillow 與 matplotlib)。
訓練源碼:
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 16:30:15 2019
@author: houwenbin
"""
import numpy as np
import mxnet as mx
import logging
# Lower the root logger's threshold to DEBUG so progress messages emitted
# through `logging` during training (presumably fit()'s epoch metrics and the
# Speedometer callback) are not filtered out.
logging.getLogger().setLevel(logging.DEBUG)

# Mini-batch size shared by both iterators and by the Speedometer callback.
batch_size = 100

# MNIST dataset as a dict of arrays (keys: train_data/train_label/test_data/test_label).
mnist = mx.test_utils.get_mnist()

# Training data is reshuffled each epoch; validation data is read in order.
train_iter = mx.io.NDArrayIter(
    data=mnist['train_data'],
    label=mnist['train_label'],
    batch_size=batch_size,
    shuffle=True,
)
val_iter = mx.io.NDArrayIter(
    data=mnist['test_data'],
    label=mnist['test_label'],
    batch_size=batch_size,
)
# LeNet-style network: two (5x5 conv -> tanh -> 2x2 max-pool) stages with 20
# and 50 filters, a 500-unit tanh hidden layer, and a 10-way softmax output.
data = mx.sym.var('data')


def _conv_tanh_pool(x, num_filter):
    """Append 5x5 convolution -> tanh -> 2x2 / stride-2 max pooling to *x*."""
    x = mx.sym.Convolution(data=x, kernel=(5, 5), num_filter=num_filter)
    x = mx.sym.Activation(data=x, act_type="tanh")
    return mx.sym.Pooling(data=x, pool_type="max", kernel=(2, 2), stride=(2, 2))


# Layers are created strictly in the order conv, act, pool, conv, act, pool,
# flatten, fc, act, fc, softmax: MXNet auto-names unnamed layers (and thus the
# saved parameter names) from per-operator counters in creation order.
features = _conv_tanh_pool(_conv_tanh_pool(data, 20), 50)
hidden = mx.sym.Activation(
    data=mx.sym.FullyConnected(data=mx.sym.Flatten(data=features), num_hidden=500),
    act_type="tanh",
)
# Softmax loss layer named 'softmax'; its label input is 'softmax_label'.
lenet = mx.sym.SoftmaxOutput(
    data=mx.sym.FullyConnected(data=hidden, num_hidden=10),
    name='softmax',
)
# Number of training epochs; also used as the checkpoint epoch number so the
# saved file name always matches what was actually trained.
num_epoch = 10

# Create a trainable module on the CPU (use mx.gpu(0) instead with a CUDA
# build of MXNet and an available GPU).
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

# Train with SGD, reporting accuracy on the validation set after each epoch.
lenet_model.fit(
    train_iter,
    eval_data=val_iter,
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1},
    eval_metric='acc',
    # Log throughput every 100 batches.
    batch_end_callback=mx.callback.Speedometer(batch_size, 100),
    num_epoch=num_epoch,
)

# Save "lenet-symbol.json" and "lenet-<num_epoch:04d>.params" (optimizer state
# not saved). The prediction script must load the same prefix/epoch.
lenet_model.save_checkpoint("lenet", num_epoch, save_optimizer_states=False)
預測源碼:
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 20:17:26 2019
@author: houwenbin
"""
import time
import mxnet as mx
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
#
# Checkpoint to load: mx.model.load_checkpoint reads "<prefix>-symbol.json"
# and "<prefix>-<iteration:04d>.params".
prefix, iteration = 'lenet', 10
# Image to classify.
img_name = './digit_8.jpg'
# Class labels indexed by network output position: output i is digit i.
synsets = list(range(10))
# MNIST-style preprocessing of a single digit image.
def load_image(img_name):
    """Load an image file and convert it to a 1x1x28x28 float32 array.

    The image is resized to 28x28, converted to 8-bit grayscale
    (L = 0.299R + 0.587G + 0.114B), displayed with matplotlib for a visual
    check, and scaled from [0, 255] to [0, 1].

    Args:
        img_name: path to the image file.

    Returns:
        numpy.ndarray of shape (1, 1, 28, 28), dtype float32; or None if the
        file does not exist or cannot be decoded as an image.
    """
    # Image.open never returns None -- it raises (FileNotFoundError, or an
    # OSError subclass for unreadable data), so catch that to honour the
    # "None on failure" contract the caller checks. `with` closes the file.
    try:
        with Image.open(img_name) as src:
            # Explicit anti-aliasing filter: before Pillow 7.0 the default
            # resize filter was NEAREST, which drops thin strokes when
            # shrinking a large drawing down to 28x28.
            small = src.resize((28, 28), Image.LANCZOS)
    except OSError:
        return None
    # Grayscale float array; 'f' selects a float dtype.
    img = np.array(small.convert('L'), 'f')
    print(img.shape)
    # Show as grayscale; without cmap matplotlib would apply its default colormap.
    plt.imshow(img, cmap='gray')
    plt.show()
    # Add batch and channel axes and scale pixel values to [0, 1].
    return img.reshape(1, 1, 28, 28).astype(np.float32) / 255
# Load the trained checkpoint and rebuild an inference-only module, timing it.
# perf_counter is monotonic and high-resolution, unlike time.time().
time0 = time.perf_counter()
# Load the MXNet symbol plus argument/auxiliary parameters.
sym, arg, aux = mx.model.load_checkpoint(prefix, iteration)
# Rebuild the module without label inputs (inference only).
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
# Allocate memory for a single 1x1x28x28 input; no label shapes at inference.
mod.bind(for_training=False, data_shapes=[('data', (1, 1, 28, 28))], label_shapes=None)
# Load the parameters. NOTE(review): allow_missing=True is presumably needed
# because the SoftmaxOutput label input ('softmax_label') is not among the
# saved params when label_names=None -- confirm.
mod.set_params(arg, aux, allow_missing=True)
# MXNet runs array copies asynchronously; wait so the timing is complete.
mx.nd.waitall()
time1 = time.perf_counter()
print("模型加載和重建時間:{0}".format(time1 - time0))
#
# Load and preprocess the test image; abort with a non-zero status if it
# cannot be read. sys.exit is used because the exit() builtin is meant for
# the interactive shell, not programs.
import sys

img = load_image(img_name)
if img is None:
    sys.exit("無法加載圖片:{0}".format(img_name))
print(img.shape)
#
# Minimal data batch for Module.forward: an object whose `.data` attribute is
# a list of input NDArrays. Defined before the timer so only inference is timed.
from collections import namedtuple

Batch = namedtuple('Batch', ['data'])

time0 = time.perf_counter()
# Run inference on the 1x1x28x28 image.
mod.forward(Batch([mx.nd.array(img)]))
# forward() only enqueues work on MXNet's asynchronous engine; block until it
# finishes so the measured time covers the actual computation.
mx.nd.waitall()
time1 = time.perf_counter()
print("前向預測時間:{0}".format(time1 - time0))
# Print the raw network output and the top-5 predictions.
outputs = mod.get_outputs()
print(outputs)
# Softmax scores as a numpy array (one row for the single input image).
prob = outputs[0].asnumpy()
print("-------result-------", prob, prob.shape)
prob = np.squeeze(prob)
print("-------squeeze result-------", prob, prob.shape)
# Scores sorted in ascending order, and the class indices in that order.
print("-------sorted prob--------", np.sort(prob))
ascending = np.argsort(prob)
print("-------arg sorted prob--------", ascending)
# Class indices from most to least confident.
a = ascending[::-1]
print("------top sorted index-------", a, a.shape)
for i in a[:5]:
    print('probability=%f, class=%s' % (prob[i], synsets[i]))
數據準備:
使用畫圖工具新建一張 128x128 的黑色背景圖片,再用橡皮擦在上面擦出白色的待識別數字,得到黑底白字的圖片(本文為 digit_8.jpg),與 MNIST 數據集的黑底白字風格一致;預測腳本會將其縮放為 28x28 並轉為灰度圖。
運行結果(注意:digit_8.jpg 的 Top-1 預測為類別 4,即本例識別有誤):
in[1]:runfile('C:/Users/houwenbin/Documents/PythonProject/test_mnist.py', wdir='C:/Users/houwenbin/Documents/PythonProject')
模型加載和重建時間:0.0060160160064697266
(28, 28)
(1, 1, 28, 28)
前向預測時間:0.0010042190551757812
[
[[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]]
<NDArray 1x10 @cpu(0)>]
-------result------- [[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]] (1, 10)
-------squeeze result------- [3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05] (10,)
-------sorted prob-------- [4.0004899e-10 1.1044953e-08 1.3175709e-06 2.3288235e-06 3.0556594e-06
4.1811345e-06 1.8727254e-05 3.0342795e-05 3.0720061e-05 9.9990916e-01]
-------arg sorted prob-------- [5 3 1 8 0 2 7 6 9 4]
------top sorted index------- [4 9 6 7 2 0 8 1 3 5] (10,)
probability=0.999909, class=4
probability=0.000031, class=9
probability=0.000030, class=6
probability=0.000019, class=7
probability=0.000004, class=2
in[2]: