學習Tensorflow已經有一段時間了,就想能不能實現手寫體的實時在線識別,於是進行了一番探索。本文源代碼可以在這裏下載【Python3+PyQt5+Tensorflow】創建畫板,實時在線手寫體識別】
用到的庫:Python3.6.5 + PyQt5 + PIL,編寫GUI程序,實現手寫體實時在線識別。最終實現的效果如下圖所示,在方框裏用鼠標手寫數字,左下角顯示識別結果,準確率可以達到99.2%。
1.畫板GUI及模型加載(MyMnistWindow.py)
使用PyQt5製作了一個交互式畫板,可以用鼠標在上面寫字。畫板的程序部分參考了【Python3使用PyQt5製作簡單的畫板/手寫板】。
'''
功能:
利用訓練好的模型,進行實時手寫體識別
作者:yuhansgg
博客: https://blog.csdn.net/u011389706
日期: 2018/08/06
'''
import tensorflow as tf
from PyQt5.QtWidgets import (QWidget, QPushButton, QLabel)
from PyQt5.QtGui import (QPainter, QPen, QFont)
from PyQt5.QtCore import Qt
from PIL import ImageGrab, Image
class MyMnistWindow(QWidget):
def __init__(self):
super(MyMnistWindow, self).__init__()
self.resize(284, 330) # resize設置寬高
self.move(100, 100) # move設置位置
self.setWindowFlags(Qt.FramelessWindowHint) # 窗體無邊框
#setMouseTracking設置爲False,否則不按下鼠標時也會跟蹤鼠標事件
self.setMouseTracking(False)
self.pos_xy = [] #保存鼠標移動過的點
# 添加一系列控件
self.label_draw = QLabel('', self)
self.label_draw.setGeometry(2, 2, 280, 280)
self.label_draw.setStyleSheet("QLabel{border:1px solid black;}")
self.label_draw.setAlignment(Qt.AlignCenter)
self.label_result_name = QLabel('識別結果:', self)
self.label_result_name.setGeometry(2, 290, 60, 35)
self.label_result_name.setAlignment(Qt.AlignCenter)
self.label_result = QLabel(' ', self)
self.label_result.setGeometry(64, 290, 35, 35)
self.label_result.setFont(QFont("Roman times", 8, QFont.Bold))
self.label_result.setStyleSheet("QLabel{border:1px solid black;}")
self.label_result.setAlignment(Qt.AlignCenter)
self.btn_recognize = QPushButton("識別", self)
self.btn_recognize.setGeometry(110, 290, 50, 35)
self.btn_recognize.clicked.connect(self.btn_recognize_on_clicked)
self.btn_clear = QPushButton("清空", self)
self.btn_clear.setGeometry(170, 290, 50, 35)
self.btn_clear.clicked.connect(self.btn_clear_on_clicked)
self.btn_close = QPushButton("關閉", self)
self.btn_close.setGeometry(230, 290, 50, 35)
self.btn_close.clicked.connect(self.btn_close_on_clicked)
def paintEvent(self, event):
painter = QPainter()
painter.begin(self)
pen = QPen(Qt.black, 30, Qt.SolidLine)
painter.setPen(pen)
if len(self.pos_xy) > 1:
point_start = self.pos_xy[0]
for pos_tmp in self.pos_xy:
point_end = pos_tmp
if point_end == (-1, -1):
point_start = (-1, -1)
continue
if point_start == (-1, -1):
point_start = point_end
continue
painter.drawLine(point_start[0], point_start[1], point_end[0], point_end[1])
point_start = point_end
painter.end()
def mouseMoveEvent(self, event):
'''
按住鼠標移動事件:將當前點添加到pos_xy列表中
'''
#中間變量pos_tmp提取當前點
pos_tmp = (event.pos().x(), event.pos().y())
#pos_tmp添加到self.pos_xy中
self.pos_xy.append(pos_tmp)
self.update()
def mouseReleaseEvent(self, event):
'''
重寫鼠標按住後鬆開的事件
在每次鬆開後向pos_xy列表中添加一個斷點(-1, -1)
'''
pos_test = (-1, -1)
self.pos_xy.append(pos_test)
self.update()
def btn_recognize_on_clicked(self):
bbox = (104, 104, 380, 380)
im = ImageGrab.grab(bbox) # 截屏,手寫數字部分
im = im.resize((28, 28), Image.ANTIALIAS) # 將截圖轉換成 28 * 28 像素
recognize_result = self.recognize_img(im) # 調用識別函數
self.label_result.setText(str(recognize_result)) # 顯示識別結果
self.update()
def btn_clear_on_clicked(self):
self.pos_xy = []
self.label_result.setText('')
self.update()
def btn_close_on_clicked(self):
self.close()
def recognize_img(self, img): # 手寫體識別函數
myimage = img.convert('L') # 轉換成灰度圖
tv = list(myimage.getdata()) # 獲取圖片像素值
tva = [(255 - x) * 1.0 / 255.0 for x in tv] # 轉換像素範圍到[0 1], 0是純白 1是純黑
init = tf.global_variables_initializer()
saver = tf.train.Saver
with tf.Session() as sess:
sess.run(init)
saver = tf.train.import_meta_graph('minst_cnn_model.ckpt.meta') # 載入模型結構
saver.restore(sess, 'minst_cnn_model.ckpt') # 載入模型參數
graph = tf.get_default_graph() # 加載計算圖
x = graph.get_tensor_by_name("x:0") # 從模型中讀取佔位符變量
keep_prob = graph.get_tensor_by_name("keep_prob:0")
y_conv = graph.get_tensor_by_name("y_conv:0") # 關鍵的一句 從模型中讀取佔位符變量
prediction = tf.argmax(y_conv, 1)
predint = prediction.eval(feed_dict={x: [tva], keep_prob: 1.0}, session=sess) # feed_dict輸入數據給placeholder佔位符
print(predint[0])
return predint[0]
識別時,先利用函數ImageGrab.grab(bbox),對屏幕畫板部分進行截圖。然後對截圖進行預處理(縮放到28*28像素,轉換成灰度圖等)。
在最後,最重要的手寫體識別函數裏recognize_img(self, img),我們調用了已經訓練好的模型minst_cnn_model.ckpt,具體模型訓練過程,參見【Tensorflow(2):MNIST識別自己手寫的數字–進階篇(CNN)】
2.主程序(main.py)
實例化我們上面定義的窗體類MyMnistWindow,實現窗體顯示。
import sys
from PyQt5.QtWidgets import QApplication
from MyMnistWindow import MyMnistWindow
if __name__ == "__main__":
app = QApplication(sys.argv)
mymnist = MyMnistWindow()
mymnist.show()
app.exec_()
3.實驗結果
最終測試結果如下所示,左下角顯示識別結果。可以看到,基本都能正確識別:
4.注意事項
若識別準確率不高,一般是由於我們手寫數字和訓練數據相差太大導致的。
一般原因是:訓練集是西方的手寫數字,和中國的手寫數字習慣不同。下面是官方的訓練數據中的部分數字。
在畫圖時,筆法儘量和上面訓練集保持一致,就會得到較高的識別率!。
本文源代碼可以在這裏下載【Python3+PyQt5+Tensorflow】創建畫板,實時在線手寫體識別】
是以爲記!