我的第一個tensorflow程序

第一個tensorflow是從網上抄來的,但是還是爬了個大坑,在預測文件中的圖片轉換爲28*28尺寸的時候用PIL一直報錯(原作者的代碼),後來改用cv2模塊resize問題就解決了,這是一個關於數字識別的程序,程序能夠在一張只有一個0-9數字的圖片中準確識別出數字是多少,準確率高達99+%,然後我用PyQt5封裝了一下,使其可視化。環境爲Python3+tensorflow2.0+PyQt5,首先創建一個python project,然後往裏面添加文件夾,然後在v4_cnn目錄下創建三個文件,mainUI.py,predict.py,train.py三個文件 ,ckpt文件目錄是沒有的,運行程序後生成的,想要運行該程序,必須先運行訓練代碼,train.py文件,然後再運行主UI文件maiUI.py文件。

 

訓練文件train.py,代碼如下:

import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

y1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4

class CNN(object):
    def __init__(self):
        model = models.Sequential()
        # 第1層卷積,卷積核大小爲3*3,32個,28*28爲待訓練圖片的大小
        model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第2層卷積,卷積核大小爲3*3,64個
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第3層卷積,卷積核大小爲3*3,64個
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))
        model.summary()
        self.model = model

class DataSource(object):
    def __init__(self):
        # mnist數據集存儲的位置,如何不存在將自動下載
        data_path = os.path.abspath(os.path.dirname(__file__)) + '/../data_set_tf2/mnist.npz'
        (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)
        # 6萬張訓練圖片,1萬張測試圖片
        train_images = train_images.reshape((60000, 28, 28, 1))
        test_images = test_images.reshape((10000, 28, 28, 1))
        # 像素值映射到 0 - 1 之間
        train_images, test_images = train_images / 255.0, test_images / 255.0

        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels

class Train:
    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self):
        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # period 每隔5epoch保存一次
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, period=5)
        self.cnn.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        self.cnn.model.fit(self.data.train_images, self.data.train_labels, epochs=5, callbacks=[save_model_cb])
        test_loss, test_acc = self.cnn.model.evaluate(self.data.test_images, self.data.test_labels)
        print("準確率: %.4f,共測試了%d張圖片 " % (test_acc, len(self.data.test_labels)))


if __name__ == "__main__":
    app = Train()
    app.train()

預測文件predict.py,代碼如下:

import tensorflow as tf
from PIL import Image
import numpy as np
from v4_cnn.train import CNN
import cv2


class Predict(object):
    def __init__(self):
        latest = tf.train.latest_checkpoint('./ckpt')
        self.cnn = CNN()
        # 恢復網絡權重
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        # 以黑白方式讀取圖片
        img = Image.open(image_path).convert('L') #爬了個大坑
        img = np.asarray(img)
        img = cv2.resize(img,(28,28))
        flatten_img = np.reshape(img, (28, 28, 1))
        x = np.array([1 - flatten_img])
        # API refer: https://keras.io/models/model/
        self.y = self.cnn.model.predict(x)

主UI文件mainUI.py,代碼如下:

import cv2
import sys
from PyQt5 import QtGui
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QHBoxLayout, QMainWindow, QDockWidget, QPushButton, \
    QVBoxLayout, QTextEdit, QFileDialog
from v4_cnn.predict import Predict
import numpy as np

class QPixmapDemo(QMainWindow):
    def __init__(self):
        super().__init__()
        self.txt = {0:'0', 1:'1', 2:'2', 3:'3', 4:'4', 5:'5', 6:'6', 7:'7', 8:'8', 9:'9'}
        self.setWindowTitle('picture')
        self.wgt = QWidget()
        # self.wgt.resize(600, 500)
        self.imgLabel = QLabel()
        self.imgLabel.resize(600, 600)  # 設置label的大小,圖片會適配label的大小
        self.hbox = QHBoxLayout()
        self.hbox.addWidget(self.imgLabel)
        self.wgt.setLayout(self.hbox)
        self.setCentralWidget(self.wgt)
        self.docker = docker(self)
        self.addDockWidget(Qt.LeftDockWidgetArea,self.docker)
        self.docker.btn_openFile.clicked.connect(self.openFile)
        self.docker.btn_startDiscern.clicked.connect(self.start)
        self.resize(800,600)

    def openFile(self):
        self.file, filetype = QFileDialog.getOpenFileName(self,
                                                          "選擇只有一個數字的圖片",
                                                          "./",
                                                          "All Files (*);;Text Files (*.txt)")
        if self.file is not None:
            self.setImage(self.file)

    def start(self):
        discern = Predict()
        discern.predict(self.file)
        num = np.argmax(discern.y[0])
        self.docker.texEdit.setText(str(num))

    def setImage(self, file):
        img = cv2.imread(file)  # opencv讀取圖片
        img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # opencv讀取的bgr格式圖片轉換成rgb格式
        _image = QtGui.QImage(img2[:], img2.shape[1], img2.shape[0], img2.shape[1] * 3,
                              QtGui.QImage.Format_RGB888)  # pyqt5轉換成自己能放的圖片格式
        jpg_out = QtGui.QPixmap(_image).scaled(self.imgLabel.width(), self.imgLabel.height())  # 設置圖片大小
        self.imgLabel.setPixmap(jpg_out)  # 設置圖片顯示


class docker(QDockWidget):
    def __init__(self, parent):
        super().__init__(parent)
        self.btn_openFile = QPushButton('打開圖片')
        self.btn_startDiscern = QPushButton('開始識別')
        self.texEdit = QTextEdit()
        self.vbox = QVBoxLayout()
        self.vbox.addWidget(self.btn_openFile)
        self.vbox.addWidget(self.btn_startDiscern)
        self.vbox.addWidget(self.texEdit)
        self.wgt = QWidget()
        self.wgt.setLayout(self.vbox)
        self.setWidget(self.wgt)


if __name__ == '__main__':
    app = QApplication(sys.argv)
    win = QPixmapDemo()
    win.show()
    sys.exit(app.exec_())

運行結果:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章