Raspberry Pi 4 Learning Log (7) - Real-Time MNIST Handwritten Digit Recognition

1. Data Preparation and Model Training

1.1 Data Preparation and Framework Selection

What we are implementing today is real-time recognition of the handwritten digit set (MNIST), so the choices are straightforward:

Dataset: the standard MNIST dataset

Since TensorFlow is what I already have installed on the Raspberry Pi, I naturally stuck with it as the implementation framework:

Framework: TensorFlow 1.13.1
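
To double-check which version is installed on the Pi:

import tensorflow as tf
print(tf.__version__)  # should print 1.13.1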

1.2 Training the Model

Straight to the code:

# coding: utf-8

# do not train on the Raspberry Pi:
# CPU usage will shoot past 90%
# boom!!!

import input_data
import tensorflow as tf

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# sess = tf.InteractiveSession()

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

x_image = tf.reshape(x, [-1,28,28,1])

# layer one
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# layer two
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# fc layer 
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# drop out
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# softmax
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# loss function
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
# train step
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# correct number
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
# accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# initiate all variables
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# start training 
steps = 2500
for i in range(steps):
    # generate a batch of images
    batch = mnist.train.next_batch(50)
    # every 100 steps, report the training accuracy
    if i%100 == 0:
        train_accuracy = accuracy.eval(session=sess, feed_dict={
            x:batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d/%d, training accuracy %g" %(i, steps, train_accuracy))
    sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

saver = tf.train.Saver()
saver.save(sess, "model/model.ckpt")

print("test accuracy %g" %accuracy.eval(session=sess, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

The final accuracy roughly stabilizes at:

test accuracy 0.9783

A bit low for this network (more than 2500 training steps should push it toward 99%), but good enough for our purposes.
The code uses an input_data script, which can be found in the TensorFlow repo:
input_data.py
Alternatively, you can replace it with the import below to load the MNIST dataset:

from tensorflow.examples.tutorials.mnist import input_data
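
With that import, the loading call stays exactly the same; on the first run it downloads the dataset into MNIST_data and caches it there:

from tensorflow.examples.tutorials.mnist import input_data

# downloads and caches MNIST under ./MNIST_data on first run
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)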

After training we obtain a TensorFlow model checkpoint.
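A TF 1.x Saver checkpoint is a set of files rather than a single one; with the save path used above, the model/ directory should look roughly like this:

model/
├── checkpoint
├── model.ckpt.data-00000-of-00001
├── model.ckpt.index
└── model.ckpt.meta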
The model is small; the Raspberry Pi can run it with room to spare.
Download the model to your local machine, ready to port to the Raspberry Pi.

2. Building the Transmission and Recognition Framework on the Raspberry Pi

Building on the earlier real-time UDP transmission script, we add our own digit-recognition step to the stream.
Again, without further ado, the code:

# coding: utf-8

import cv2
import numpy as np 
import socket
import struct
import input_data
import tensorflow as tf

# load the MNIST dataset (only used to warm up the graph below)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

# CNN definition (must match the training graph exactly)
def weight_variable(shape):
	initial = tf.truncated_normal(shape, stddev=0.1)
	return tf.Variable(initial)

def bias_variable(shape):
	initial = tf.constant(0.1, shape=shape)
	return tf.Variable(initial)

def conv2d(x, W):
	return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
	return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

x_image = tf.reshape(x, [-1,28,28,1])

# layer one
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# layer two
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# fc layer 
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# drop out
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# softmax
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# create a session
sess = tf.Session()
# initialize variables (immediately overwritten by the checkpoint restore below)
sess.run(tf.global_variables_initializer())

# restore the trained model
saver = tf.train.Saver()
saver.restore(sess, "model/model.ckpt")

# create the UDP socket
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.bind(("192.168.1.7", 6000))

print("UDP bound on port 6000...")
print('now starting to send frames...')

# create the video capture object
capture = cv2.VideoCapture(0)
# small trick: block until the client sends a hello packet, so we learn its address
data, addr = s.recvfrom(1024)
# set the capture resolution (properties 3 and 4 are frame width and height)
capture.set(3, 256)
capture.set(4, 256)

# warm up TensorFlow with a few MNIST test samples
print("preparing tensorflow...")
for i in range(10):
    test_batch = mnist.test.next_batch(1)
    predict_result = sess.run(y_conv, feed_dict={x: test_batch[0], y_: test_batch[1], keep_prob: 1.0})
    number = np.where(predict_result == np.max(predict_result))
    print("for the %d time "%(i+1), number[1].tolist())

# pretend to feed the model a label (it is ignored at inference time)
tmp_array = np.array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

# main loop
print("start rec and trans....")
while True:
    success, frame = capture.read()
    while not success or frame is None:
        success, frame = capture.read()  # keep grabbing until a frame arrives

    # keep only the centre of the frame to reduce background clutter
    frame = frame[36:220, 36:220]
    # shrink to the MNIST input size (28x28)
    test_image = cv2.resize(frame, (28, 28))
    # convert to grayscale
    gray = cv2.cvtColor(test_image, cv2.COLOR_BGR2GRAY)
    # binarize with Otsu; note that MNIST is white-on-black,
    # so the threshold must be inverted (THRESH_BINARY_INV)
    ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

    # run the prediction; scale pixels to [0, 1] to match the training data
    predict_result = sess.run(y_conv, feed_dict={x: np.reshape(binary, (1, 784)) / 255.0, y_: tmp_array, keep_prob: 1.0})
    # index of the highest-probability class
    number = np.where(predict_result[0] == np.max(predict_result[0]))

    # overlay the predicted digit on the frame
    cv2.putText(frame, "num: " + str(number[0].tolist()[0]), (0, 50), cv2.FONT_HERSHEY_COMPLEX, 0.8, (100, 200, 200), 1)

    # JPEG-encode the frame and send it: a 4-byte length packet first, then the payload
    result, imgencode = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 50])

    s.sendto(struct.pack('i', imgencode.shape[0]), addr)
    s.sendto(imgencode, addr)

s.close()

Client (it first reads the 4-byte length packet, then receives the JPEG payload):

# coding: utf-8

import cv2
import numpy
import socket
import struct

s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
addr = ("192.168.1.7", 6000)

# say hello so the server learns our address
data = 'hello'
s.sendto(data.encode(), addr)

print('now waiting for frames...')
while True:
    data, addr = s.recvfrom(65535)
    if len(data) == 1 and data[0] == 1:  # stop the program if a close message arrives
        s.close()
        cv2.destroyAllWindows()
        exit()
    if len(data) != 4:  # simple sanity check: the length prefix is an int, four bytes
        length = 0
    else:
        length = struct.unpack('i', data)[0]  # frame length in bytes
    data, address = s.recvfrom(65535)
    if length != len(data):  # drop the frame if the length does not match
        continue
    data = numpy.array(bytearray(data))  # convert the payload to a numpy array
    imgdecode = cv2.imdecode(data, 1)  # decode the JPEG
    # print('have received one frame')
    cv2.imshow('frames', imgdecode)  # show the frame
    if cv2.waitKey(1) == 27:  # press ESC to quit
        break

s.close()
cv2.destroyAllWindows()

3. Porting the Model to the Raspberry Pi

This step is actually simple: upload the model to the matching path on the Raspberry Pi; no special handling is needed. Place it in:

./model/

Then it can be run directly, the same way as before (start the server first, then the client).
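
For example, assuming the server script above is saved on the Pi as server.py and the client script on the PC as client.py (these file names are my own, not from the original):

# on the Raspberry Pi, start the server first
python server.py
# then on the PC, start the client
python client.py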

4. Testing

Test results: the streamed frames show the predicted digit overlaid on the camera view, and recognition works well. Onward and upward.

5. Acknowledgments and Related Links

On image binarization:
OpenCV—圖像二值化

Why MNIST needs white digits on a black background (a short sketch follows this list):
Mnist模型識別自己手寫數字正確率低的原因

How to crop part of an image (array slicing):
python數組截取
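
As a minimal sketch of the black-background point (digit.jpg is a hypothetical photo of a digit written in dark ink on light paper):

import cv2

gray = cv2.imread('digit.jpg', cv2.IMREAD_GRAYSCALE)  # hypothetical sample photo
# MNIST digits are white strokes on black, so dark-on-light camera input
# must be inverted; Otsu picks the threshold automatically
ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
cv2.imwrite('digit_mnist_style.png', binary)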

Thanks to the tutorials above for saving me a lot of detours.
