我的第一個 TensorFlow 程序是從網上抄來的,但還是踩了個大坑:預測文件把圖片轉換爲 28*28 尺寸時,用 PIL(原作者的代碼)一直報錯,後來改用 cv2 模塊的 resize,問題就解決了。這是一個數字識別程序,能夠在一張只含一個 0-9 數字的圖片中準確識別出數字是多少,準確率高達 99% 以上;然後我用 PyQt5 封裝了一下,使其可視化。環境爲 Python3 + tensorflow2.0 + PyQt5。首先創建一個 python project,然後往裏面添加文件夾,再在 v4_cnn 目錄下創建 mainUI.py、predict.py、train.py 三個文件。ckpt 目錄一開始是沒有的,是運行訓練程序後生成的。想要運行該程序,必須先運行訓練代碼 train.py,然後再運行主 UI 文件 mainUI.py。
訓練文件train.py,代碼如下:
import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np
# Two ten-element score lists (presumably softmax-style outputs, one entry per
# digit 0-9 -- confirm with the author). np.argmax returns the index of the
# largest entry; the results are discarded here, so these calls only
# illustrate how a prediction vector is turned into a digit.
y1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4
class CNN(object):
    """Small convolutional classifier for 28x28 single-channel images.

    The assembled Keras ``Sequential`` network is exposed as ``self.model``.
    Its last layer is a 10-unit softmax. Constructing the object also prints
    the model summary.
    """

    def __init__(self):
        network = models.Sequential([
            # Conv layer 1: 32 kernels of size 3x3; 28x28x1 is the input image shape.
            layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            # Conv layer 2: 64 kernels of size 3x3.
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            # Conv layer 3: 64 kernels of size 3x3.
            layers.Conv2D(64, (3, 3), activation='relu'),
            # Classification head: flatten, one hidden dense layer, softmax output.
            layers.Flatten(),
            layers.Dense(64, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
        network.summary()
        self.model = network
class DataSource(object):
    """Load the MNIST dataset and expose it as normalised arrays.

    Attributes:
        train_images, test_images: image arrays reshaped to (N, 28, 28, 1)
            and scaled from 0-255 to the range 0-1.
        train_labels, test_labels: label arrays exactly as returned by
            ``datasets.mnist.load_data``.
    """

    def __init__(self):
        # Location of the cached mnist.npz, resolved relative to this file;
        # the dataset is downloaded there when the file is missing.
        base_dir = os.path.abspath(os.path.dirname(__file__))
        data_path = os.path.join(base_dir, os.pardir, 'data_set_tf2', 'mnist.npz')
        # Make sure the target directory exists before Keras tries to write
        # the download into it; otherwise the first run fails.
        os.makedirs(os.path.dirname(data_path), exist_ok=True)

        (train_images, train_labels), (test_images, test_labels) = \
            datasets.mnist.load_data(path=data_path)

        # Add the trailing channel axis; -1 lets NumPy infer the sample count
        # (60,000 training / 10,000 test images for standard MNIST).
        train_images = train_images.reshape((-1, 28, 28, 1))
        test_images = test_images.reshape((-1, 28, 28, 1))

        # Map pixel values into the range 0 - 1.
        train_images, test_images = train_images / 255.0, test_images / 255.0

        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels
class Train:
    """Tie the CNN to the MNIST data source and run training and evaluation."""

    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self):
        """Compile and fit the model for 5 epochs, then print test accuracy.

        Weights are checkpointed under ``./ckpt`` through a ModelCheckpoint
        callback configured with ``period=5``.
        """
        model = self.cnn.model
        dataset = self.data

        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # period=5: save once every 5 epochs.
        checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
            check_path,
            save_weights_only=True,
            verbose=1,
            period=5,
        )

        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'],
        )
        model.fit(
            dataset.train_images,
            dataset.train_labels,
            epochs=5,
            callbacks=[checkpoint_cb],
        )

        _, test_acc = model.evaluate(dataset.test_images, dataset.test_labels)
        sample_count = len(dataset.test_labels)
        print("準確率: %.4f,共測試了%d張圖片 " % (test_acc, sample_count))
if __name__ == "__main__":
    # Script entry point: build the trainer (model + data) and run training.
    Train().train()
預測文件predict.py,代碼如下:
import tensorflow as tf
from PIL import Image
import numpy as np
from v4_cnn.train import CNN
import cv2
class Predict(object):
    """Restore the trained CNN from ``./ckpt`` and classify digit images.

    Attributes:
        cnn: the ``CNN`` wrapper whose ``model`` holds the restored weights.
        y: output of the most recent ``predict`` call (set by ``predict``).
    """

    def __init__(self):
        """Build the network and load the newest checkpoint.

        Raises:
            FileNotFoundError: if no checkpoint exists under ``./ckpt``
                (i.e. training has not been run yet).
        """
        latest = tf.train.latest_checkpoint('./ckpt')
        if latest is None:
            # latest_checkpoint returns None when nothing is found; fail with
            # a clear message instead of an obscure error in load_weights.
            raise FileNotFoundError(
                "No checkpoint found under './ckpt'; run train.py first."
            )
        self.cnn = CNN()
        # Restore the network weights.
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        """Run the model on the image at ``image_path``.

        The image is read as 8-bit grayscale, resized to 28x28, scaled to the
        0-1 range and inverted before being fed to the network.

        Args:
            image_path: path of an image file readable by PIL.

        Returns:
            The model output for the single-image batch; the same object is
            stored on ``self.y`` for callers that read the attribute.
        """
        # Read as grayscale; the context manager closes the file handle.
        with Image.open(image_path) as source:
            gray = np.asarray(source.convert('L'))
        # Resizing with PIL kept failing for the author, so cv2 is used here.
        gray = cv2.resize(gray, (28, 28))

        # Scale to 0-1 first (the training images were divided by 255), then
        # invert. Doing ``1 - img`` directly on the uint8 array wraps around
        # and feeds 0-255 values to a network trained on 0-1 inputs.
        pixels = gray.astype(np.float32) / 255.0
        x = (1.0 - pixels).reshape((1, 28, 28, 1))

        # API refer: https://keras.io/models/model/
        self.y = self.cnn.model.predict(x)
        return self.y
主UI文件mainUI.py,代碼如下:
import cv2
import sys
from PyQt5 import QtGui
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QHBoxLayout, QMainWindow, QDockWidget, QPushButton, \
QVBoxLayout, QTextEdit, QFileDialog
from v4_cnn.predict import Predict
import numpy as np
class QPixmapDemo(QMainWindow):
    """Main window: shows the chosen image and the recognised digit.

    Attributes:
        file: path of the currently selected image, or None if none was chosen.
        imgLabel: label in the central widget that displays the image.
        docker: left-hand dock with the open/recognise buttons and result box.
    """

    def __init__(self):
        super().__init__()
        self.txt = {0:'0', 1:'1', 2:'2', 3:'3', 4:'4', 5:'5', 6:'6', 7:'7', 8:'8', 9:'9'}
        # No image selected yet; start() checks this before predicting.
        self.file = None
        # Predict() rebuilds the network and reloads weights, so it is
        # created lazily on first use and then reused.
        self._predictor = None

        self.setWindowTitle('picture')
        self.wgt = QWidget()
        # self.wgt.resize(600, 500)
        self.imgLabel = QLabel()
        self.imgLabel.resize(600, 600)  # label size; the image is scaled to fit it
        self.hbox = QHBoxLayout()
        self.hbox.addWidget(self.imgLabel)
        self.wgt.setLayout(self.hbox)
        self.setCentralWidget(self.wgt)

        self.docker = docker(self)
        self.addDockWidget(Qt.LeftDockWidgetArea, self.docker)
        self.docker.btn_openFile.clicked.connect(self.openFile)
        self.docker.btn_startDiscern.clicked.connect(self.start)
        self.resize(800, 600)

    def openFile(self):
        """Ask the user for an image file and display it."""
        path, filetype = QFileDialog.getOpenFileName(self,
                                                     "選擇只有一個數字的圖片",
                                                     "./",
                                                     "All Files (*);;Text Files (*.txt)")
        # A cancelled dialog yields an empty string (not None); keep the
        # previous selection in that case.
        if not path:
            return
        self.file = path
        self.setImage(path)

    def start(self):
        """Run recognition on the selected image and show the digit."""
        if not self.file:
            self.docker.texEdit.setText('請先打開圖片')
            return
        if self._predictor is None:
            self._predictor = Predict()
        self._predictor.predict(self.file)
        num = np.argmax(self._predictor.y[0])
        self.docker.texEdit.setText(str(num))

    def setImage(self, file):
        """Load ``file`` with OpenCV and show it scaled to the image label.

        If OpenCV cannot decode the file, the label is cleared instead.
        """
        img = cv2.imread(file)  # OpenCV reads the image in BGR order
        if img is None:
            # imread returns None for unreadable / non-image files.
            self.imgLabel.clear()
            return
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # convert BGR to RGB for Qt
        height, width = rgb.shape[0], rgb.shape[1]
        # Wrap the pixel buffer in a QImage; 3 bytes per pixel per row.
        _image = QtGui.QImage(rgb.data, width, height, width * 3,
                              QtGui.QImage.Format_RGB888)
        # fromImage copies the pixels, so the pixmap does not depend on the
        # lifetime of the NumPy buffer.
        jpg_out = QtGui.QPixmap.fromImage(_image).scaled(
            self.imgLabel.width(), self.imgLabel.height())  # fit the label size
        self.imgLabel.setPixmap(jpg_out)  # display the image
class docker(QDockWidget):
    """Dock panel with the open-image button, recognise button and result box."""

    def __init__(self, parent):
        super().__init__(parent)

        # Controls, exposed as attributes so the main window can connect to them.
        self.btn_openFile = QPushButton('打開圖片')
        self.btn_startDiscern = QPushButton('開始識別')
        self.texEdit = QTextEdit()

        # Stack the controls vertically, in the order listed.
        self.vbox = QVBoxLayout()
        for control in (self.btn_openFile, self.btn_startDiscern, self.texEdit):
            self.vbox.addWidget(control)

        # A dock widget takes a single child widget, so wrap the layout in one.
        self.wgt = QWidget()
        self.wgt.setLayout(self.vbox)
        self.setWidget(self.wgt)
if __name__ == '__main__':
    # Script entry point: create the Qt application, show the main window and
    # hand the event loop's exit status back to the interpreter.
    application = QApplication(sys.argv)
    window = QPixmapDemo()
    window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)
運行結果: