Tensorflow(3):創建畫板,實時在線手寫體識別--終極篇(PyQt5)

版權聲明:本文爲博主原創文章,未經博主允許不得轉載。 https://blog.csdn.net/u011389706/article/details/81460820

   學習Tensorflow已經有一段時間了,就想能不能實現手寫體的實時在線識別,於是進行了一番探索。本文源代碼可以在這裏下載【Python3+PyQt5+Tensorflow】創建畫板,實時在線手寫體識別】
   用到的庫:Python3.6.5 + PyQt5 + PIL,編寫GUI程序,實現手寫體實時在線識別。最終實現的效果如下圖所示,在方框裏用鼠標手寫數字,左下角顯示識別結果,準確率可以達到99.2%。


1.畫板GUI及模型加載(MyMnistWindow.py)

  使用PyQt5製作了一個交互式畫板,可以用鼠標在上面寫字。畫板的程序部分參考了【Python3使用PyQt5製作簡單的畫板/手寫板】

'''
    功能:
        利用訓練好的模型,進行實時手寫體識別
    作者:yuhansgg
    博客: https://blog.csdn.net/u011389706
    日期: 2018/08/06
'''
import tensorflow as tf
from PyQt5.QtWidgets import (QWidget, QPushButton, QLabel)
from PyQt5.QtGui import (QPainter, QPen, QFont)
from PyQt5.QtCore import Qt
from PIL import ImageGrab, Image

class MyMnistWindow(QWidget):

    def __init__(self):
        super(MyMnistWindow, self).__init__()

        self.resize(284, 330)  # resize設置寬高
        self.move(100, 100)    # move設置位置
        self.setWindowFlags(Qt.FramelessWindowHint)  # 窗體無邊框
        #setMouseTracking設置爲False,否則不按下鼠標時也會跟蹤鼠標事件
        self.setMouseTracking(False)

        self.pos_xy = []  #保存鼠標移動過的點

        # 添加一系列控件
        self.label_draw = QLabel('', self)
        self.label_draw.setGeometry(2, 2, 280, 280)
        self.label_draw.setStyleSheet("QLabel{border:1px solid black;}")
        self.label_draw.setAlignment(Qt.AlignCenter)

        self.label_result_name = QLabel('識別結果:', self)
        self.label_result_name.setGeometry(2, 290, 60, 35)
        self.label_result_name.setAlignment(Qt.AlignCenter)

        self.label_result = QLabel(' ', self)
        self.label_result.setGeometry(64, 290, 35, 35)
        self.label_result.setFont(QFont("Roman times", 8, QFont.Bold))
        self.label_result.setStyleSheet("QLabel{border:1px solid black;}")
        self.label_result.setAlignment(Qt.AlignCenter)

        self.btn_recognize = QPushButton("識別", self)
        self.btn_recognize.setGeometry(110, 290, 50, 35)
        self.btn_recognize.clicked.connect(self.btn_recognize_on_clicked)

        self.btn_clear = QPushButton("清空", self)
        self.btn_clear.setGeometry(170, 290, 50, 35)
        self.btn_clear.clicked.connect(self.btn_clear_on_clicked)

        self.btn_close = QPushButton("關閉", self)
        self.btn_close.setGeometry(230, 290, 50, 35)
        self.btn_close.clicked.connect(self.btn_close_on_clicked)

    def paintEvent(self, event):
        painter = QPainter()
        painter.begin(self)
        pen = QPen(Qt.black, 30, Qt.SolidLine)
        painter.setPen(pen)

        if len(self.pos_xy) > 1:
            point_start = self.pos_xy[0]
            for pos_tmp in self.pos_xy:
                point_end = pos_tmp

                if point_end == (-1, -1):
                    point_start = (-1, -1)
                    continue
                if point_start == (-1, -1):
                    point_start = point_end
                    continue

                painter.drawLine(point_start[0], point_start[1], point_end[0], point_end[1])
                point_start = point_end
        painter.end()

    def mouseMoveEvent(self, event):
        '''
            按住鼠標移動事件:將當前點添加到pos_xy列表中
        '''
        #中間變量pos_tmp提取當前點
        pos_tmp = (event.pos().x(), event.pos().y())
        #pos_tmp添加到self.pos_xy中
        self.pos_xy.append(pos_tmp)

        self.update()

    def mouseReleaseEvent(self, event):
        '''
            重寫鼠標按住後鬆開的事件
            在每次鬆開後向pos_xy列表中添加一個斷點(-1, -1)
        '''
        pos_test = (-1, -1)
        self.pos_xy.append(pos_test)

        self.update()

    def btn_recognize_on_clicked(self):
        bbox = (104, 104, 380, 380)
        im = ImageGrab.grab(bbox)    # 截屏,手寫數字部分
        im = im.resize((28, 28), Image.ANTIALIAS)  # 將截圖轉換成 28 * 28 像素

        recognize_result = self.recognize_img(im)  # 調用識別函數

        self.label_result.setText(str(recognize_result))  # 顯示識別結果
        self.update()

    def btn_clear_on_clicked(self):
        self.pos_xy = []
        self.label_result.setText('')
        self.update()

    def btn_close_on_clicked(self):
        self.close()

    def recognize_img(self, img):  # 手寫體識別函數
        myimage = img.convert('L')  # 轉換成灰度圖
        tv = list(myimage.getdata())  # 獲取圖片像素值
        tva = [(255 - x) * 1.0 / 255.0 for x in tv]  # 轉換像素範圍到[0 1], 0是純白 1是純黑

        init = tf.global_variables_initializer()
        saver = tf.train.Saver  

        with tf.Session() as sess:
            sess.run(init)
            saver = tf.train.import_meta_graph('minst_cnn_model.ckpt.meta')  # 載入模型結構
            saver.restore(sess, 'minst_cnn_model.ckpt')  # 載入模型參數

            graph = tf.get_default_graph()  # 加載計算圖
            x = graph.get_tensor_by_name("x:0")  # 從模型中讀取佔位符變量
            keep_prob = graph.get_tensor_by_name("keep_prob:0")
            y_conv = graph.get_tensor_by_name("y_conv:0")  # 關鍵的一句  從模型中讀取佔位符變量

            prediction = tf.argmax(y_conv, 1)
            predint = prediction.eval(feed_dict={x: [tva], keep_prob: 1.0}, session=sess)  # feed_dict輸入數據給placeholder佔位符
            print(predint[0])
        return predint[0]

   識別時,先利用函數ImageGrab.grab(bbox),對屏幕畫板部分進行截圖。然後對截圖進行預處理(縮放到28*28像素,轉換成灰度圖等)。
   在最後,最重要的手寫體識別函數裏recognize_img(self, img),我們調用了已經訓練好的模型minst_cnn_model.ckpt,具體模型訓練過程,參見【Tensorflow(2):MNIST識別自己手寫的數字–進階篇(CNN)】

2.主程序(main.py)

   實例化我們上面定義的窗體類MyMnistWindow,實現窗體顯示。

import sys
from PyQt5.QtWidgets import QApplication
from MyMnistWindow import MyMnistWindow

if __name__ == "__main__":
    app = QApplication(sys.argv)
    mymnist = MyMnistWindow()
    mymnist.show()
    app.exec_()

3.實驗結果

   最終測試結果如下所示,左下角顯示識別結果。可以看到,基本都能正確識別:

4.注意事項

    若識別準確率不高,一般是由於我們手寫數字和訓練數據相差太大導致的
    一般原因是:訓練集是西方的手寫數字,和中國的手寫數字習慣不同。下面是官方的訓練數據中的部分數字。


    在畫圖時,筆法儘量和上面訓練集保持一致,就會得到較高的識別率!
    本文源代碼可以在這裏下載【Python3+PyQt5+Tensorflow】創建畫板,實時在線手寫體識別】
    是以爲記!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章