MXNet MNIST example

import numpy as np
import os
import urllib.request
import base64
import gzip
import struct
import matplotlib.pyplot as plt
import mxnet as mx
import logging

from IPython.display import HTML
import cv2
#from mnist_demo import html, script

def download_data(url, force_download=True):
    fname = url.split("/")[-1]
    if force_download or not os.path.exists(fname):
        urllib.request.urlretrieve(url, fname)  # unlike Python 2, Python 3 puts urlretrieve under urllib.request
    return fname
	
def read_data(label_url, image_url):
    with gzip.open(download_data(label_url), 'rb') as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.frombuffer(flbl.read(), dtype=np.int8)
    with gzip.open(download_data(image_url), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)
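# The ">II" / ">IIII" format strings above parse the big-endian IDX headers: label files
# begin with magic number 2049 followed by the item count, image files with magic 2051
# followed by the count, rows and cols. An optional sanity check (illustrative only):
#   assert magic == 2049   # in the label block
#   assert magic == 2051   # in the image block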
	
def get_data(path):
	(train_lbl, train_img) = read_data(
		path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')

	(val_lbl, val_img) = read_data(
		path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
	
	print("train_img.shape=", train_img.shape)
	print("train_lbl.shape=", train_lbl.shape)
	print("val_lbl.shape=", val_lbl.shape)
	print("val_img.shape=", val_img.shape)
	
	for i in range(10):
		plt.subplot(1,10,i+1)
		plt.imshow(train_img[i], cmap='Greys_r')
		plt.axis('off')
	plt.show()
	print('label: %s'% (train_lbl[0:10],))
	
	return (train_lbl, train_img, val_lbl, val_img)

def to4d(img):
    return img.reshape(img.shape[0],1,28,28).astype(np.float32)/255
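# to4d turns the (N, 28, 28) uint8 images into the (N, 1, 28, 28) float32 layout that the
# symbols below expect, scaled to [0, 1]; e.g. to4d(train_img).shape == (60000, 1, 28, 28).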
	
def bpnetWorkTrainMnist(train_lbl, train_img, val_lbl, val_img):
	batch_size=100
	train_iter= mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)
	val_iter= mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)
	
	# Create a place holder variable for the input data
	data= mx.sym.Variable('data')
	
	# Flatten the data from 4-D shape (batch_size, num_channel, width, height)
	# into 2-D (batch_size, num_channel*width*height)
	data= mx.sym.Flatten(data=data)
	
	# The first fully-connected layer
	fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
	
	# Apply ReLU to the output of the first fully-connected layer
	act1= mx.sym.Activation(data=fc1, name='relu1', act_type="relu")
	
	# The second fully-connected layer and its activation function
	fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden =64)
	
	act2= mx.sym.Activation(data=fc2, name='relu2', act_type="relu")
	
	# The third fully-connected layer; its hidden size is 10, the number of digit classes
	fc3  = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)
	
	# The softmax and loss layer
	mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
	
	# Visualize the network structure with output shapes (the batch_size is ignored).
	shape= {"data" : (batch_size, 1,28,28)}
	mx.viz.plot_network(symbol=mlp, shape=shape)
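	# plot_network returns a graphviz Digraph; it renders inline in a notebook, but in a
	# plain script you would need to render it yourself, e.g. (a sketch, assuming the
	# graphviz package is installed):
	#   mx.viz.plot_network(symbol=mlp, shape=shape).render('mlp')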
	
	# The network definition and the data iterators are now ready; we can start training:
	logging.getLogger().setLevel(logging.DEBUG)
	
	model= mx.model.FeedForward(
		symbol = mlp,       # network structure
		num_epoch =10,     	# number of data passes for training 
		learning_rate =0.1	# learning rate of SGD 
	)
	
	model.fit(
		X=train_iter,       # training data
		eval_data=val_iter,       # validation data
		batch_end_callback = mx.callback.Speedometer(batch_size, 200)  # report progress every 200 batches
	)	
		
	return (model, val_iter)
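# Note: mx.model.FeedForward is the old MXNet API and was removed in later releases.
# A rough equivalent with the Module API (a sketch, assuming MXNet 1.x) would be:
#   mod = mx.mod.Module(symbol=mlp, context=mx.cpu(),
#                       data_names=['data'], label_names=['softmax_label'])
#   mod.fit(train_iter, eval_data=val_iter,
#           optimizer='sgd', optimizer_params={'learning_rate': 0.1},
#           num_epoch=10,
#           batch_end_callback=mx.callback.Speedometer(batch_size, 200))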

def classify(model, img):
	# img is a base64 data URI ("data:image/png;base64,...") produced by the HTML canvas demo
	img = base64.b64decode(img[len('data:image/png;base64,'):])
	img = cv2.imdecode(np.frombuffer(img, np.uint8), -1)  # decode the PNG, keeping the alpha channel
	img = cv2.resize(img[:, :, 3], (28, 28))              # the drawn strokes live in the alpha channel
	img = img.astype(np.float32).reshape((1, 1, 28, 28)) / 255.0
	return model.predict(img)[0].argmax()
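# Illustrative use of classify() (assumes `model` and `val_img` from the functions above):
# the canvas demo sends a transparent PNG whose alpha channel holds the strokes, so a
# matching data URI can be rebuilt from a validation image like this:
#   bgra = np.zeros((28, 28, 4), dtype=np.uint8)
#   bgra[:, :, 3] = val_img[0]
#   ok, png = cv2.imencode('.png', bgra)
#   uri = 'data:image/png;base64,' + base64.b64encode(png.tobytes()).decode('ascii')
#   print(classify(model, uri))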
	
def cnnTrainMnist(train_lbl, train_img, val_lbl, val_img):
	batch_size=100
	train_iter= mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)
	val_iter= mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)
	
	data= mx.symbol.Variable('data')
	
	# first conv layer
	conv1= mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
	tanh1= mx.sym.Activation(data=conv1, act_type="tanh")
	pool1= mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
	
	# second conv layer
	conv2= mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
	tanh2= mx.sym.Activation(data=conv2, act_type="tanh")
	pool2= mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
	
	# first fullc layer
	flatten= mx.sym.Flatten(data=pool2)
	fc1= mx.symbol.FullyConnected(data=flatten, num_hidden=500)
	tanh3= mx.sym.Activation(data=fc1, act_type="tanh")
	
	# second fullc
	fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
	
	# softmax loss
	lenet= mx.sym.SoftmaxOutput(data=fc2, name='softmax')
	
	# Visualize the network structure with output shapes (the batch_size is ignored).
	shape= {"data" : (batch_size, 1,28,28)}
	mx.viz.plot_network(symbol=lenet, shape=shape)
	
	# The network definition and the data iterators are now ready; we can start training:
	logging.getLogger().setLevel(logging.DEBUG)
	
	model= mx.model.FeedForward(
		ctx = mx.gpu(0),     # use GPU 0 for training; everything else is the same as before
		symbol = lenet,       
		num_epoch =10,     
		learning_rate = 0.1
	)
	
	model.fit(
		X=train_iter,  
		eval_data=val_iter,
		batch_end_callback = mx.callback.Speedometer(batch_size,200)
	)
	
	return (model, val_iter)
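# mx.gpu(0) above assumes a CUDA-enabled MXNet build with at least one GPU available.
# A small fallback helper (not part of the original post) that drops back to the CPU:
def pick_context():
	try:
		mx.nd.zeros((1,), ctx=mx.gpu(0)).asnumpy()  # force the allocation to run
		return mx.gpu(0)
	except Exception:
		return mx.cpu()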
	
def mnist():
	path='http://yann.lecun.com/exdb/mnist/'
	train_lbl, train_img, val_lbl, val_img = get_data(path)
	
	model, val_iter = bpnetWorkTrainMnist(train_lbl, train_img, val_lbl, val_img)
	
	#model, val_iter = cnnTrainMnist(train_lbl, train_img, val_lbl, val_img)
	
	#model.save_checkpoint('mxnet_mnist_weight', num_epoch)  # save the model
	
	# After training is done, test the model on a single image.
	plt.imshow(val_img[0], cmap='Greys_r')
	plt.axis('off')
	plt.show()
	
	# MLP-only variant (Flatten collapses the trailing dims, so 3-D input also works):
	#prob = model.predict(val_img[0:1].astype(np.float32)/255)[0]

	# Works for both the MLP and the CNN: reshape to the 4-D layout used during training
	prob = model.predict((val_img[0:1].astype(np.float32)/255).reshape(1, 1, 28, 28))[0]
	
	print('Classified as %d with probability %f'% (prob.argmax(),max(prob)))
	
	# We can also compute the validation accuracy by passing a data iterator.
	print('Validation accuracy: %f%%'% (model.score(val_iter)*100,))
	
	
	
def main():
	mnist()
	
if __name__ == '__main__':
	main()

 
