用卷積神經網絡進行mnist手寫體識別

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np


def variable_init(shape):
    """
    定義一個變量初始化函數
    :return:
    """
    return tf.Variable(initial_value=tf.random.normal(shape=shape))


def mnist_cnn(x_train):
    """
    構建卷積神經網絡進行特徵提取
    :param x:
    :return:
    """
    # 先對x[None, 784]的階數進行修改,改爲四階[None, 28, 28, 1](batch, height, width, channels)
    # 注意:reshape中-1作爲未知數佔位符, 因爲在訓練集中,輸入樣本batch爲100,而在測試集中每次輸入一個樣本
    # 進行預測。所以如果指定了reshape中的batch大小,則會在訓練完成後進行測試時出現數據格式問題。
    input_x = tf.reshape(x_train, shape=[-1, 28, 28, 1])
    with tf.variable_scope("conv1"):
        # 卷積層1
        # 設置卷積核(就是設置權重和偏置)
        filter1_Weights = variable_init([5, 5, 1, 32])
        filter1_bias = variable_init([32])
        conv1 = tf.nn.conv2d(input=input_x, filter=filter1_Weights, strides=[1, 1, 1, 1], padding="SAME") + filter1_bias
        # 激活函數:Relu
        conv1_relu = tf.nn.relu(conv1)
        # 池化層
        conv1_pool = tf.nn.max_pool(value=conv1_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    with tf.variable_scope("conv2"):
        # 卷積層2
        filter2_Weights = variable_init([5, 5, 32, 64])
        filter2_bias = variable_init([64])
        conv2 = tf.nn.conv2d(input=conv1_pool, filter=filter2_Weights, strides=[1, 1, 1, 1], padding="SAME") + filter2_bias
        # 激活函數:Relu
        conv2_relu = tf.nn.relu(conv2)
        # 池化層
        conv2_pool = tf.nn.max_pool(value=conv2_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    with tf.variable_scope("fc"):
        # 全連接層
        # 因爲要進行二階矩陣相乘,先改變形狀
        fc_input = tf.reshape(conv2_pool, shape=[-1, 7*7*64])
        fc_Weights = variable_init([7*7*64, 10])
        fc_bias = variable_init([10])
    y_predict = tf.matmul(fc_input, fc_Weights) + fc_bias

    return y_predict


def mnist_recognition():
    """
    使用cnn進行手寫體識別
    :return:
    """
    # 1、準備數據
    #    兩種數據讀取方式:
    #   (1)、QueueRunner
    #   (2)、Feeding
    mnist = input_data.read_data_sets(r"E:\GameDownload\dataset_mnist", one_hot=True)
    x_train = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10])

    # 2、構建cnn模型(注意模型參數應用變量存儲)
    y_predict = mnist_cnn(x_train)
    # print(y_predict)

    # 3、構造損失函數(用softmax表示的交叉熵)
    error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
    # 4、優化損失(使用梯度下降方法)
    optimizer = tf.train.AdamOptimizer(learning_rate=0.02).minimize(error)
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(error)
    # 5、計算準確率,對y_predict使用argmax可以找出其一行中最大值所在的列
    #    由於使用的是one-hot編碼,所以預測值與真實值在編碼內的位置相同時爲true,否則爲false
    #    之後將bool值轉爲浮點數後求均值,即爲一個batch內true的機率
    equal_list = tf.equal(tf.argmax(y_predict, 1), tf.argmax(y_true, 1))
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    init = tf.global_variables_initializer()
    with tf.compat.v1.Session() as sess:
        sess.run(init)
        # 給出一次批處理加載的圖片樣本個數
        image, label = mnist.train.next_batch(100)
        # print(x_train)
        print("image_shape:", np.shape(image))
        for i in range(3000):
            loss, _, y_predict_val, accuracy_val = sess.run([error, optimizer, y_predict, accuracy],
                                                            feed_dict={y_true: label, x_train: image})
            # print("y_predict:\n", sess.run(y_predict, feed_dict={y_true: label, x_train: image}))
            # print("第%d次迭代後:損失爲:%f, 準確率爲%f" % (i + 1, loss, accuracy_val))

        # 6、得到模型之後在測試集中進行驗證
        count = 0.0
        for i in range(100):
            x_test, y_test = mnist.test.next_batch(1)
            test_predict = tf.argmax(sess.run(y_predict, feed_dict={x_train: x_test, y_true: y_test}), 1).eval()
            test_true = tf.argmax(y_test, 1).eval()
            if test_true == test_predict:
                count += 1
            print("第%d次測試的預測值爲:%d, 真實值爲:%d" % (i+1, test_predict, test_true))
            # print("test_true:", test_true)
            # print("test_predict:", sess.run(y_predict, feed_dict={x_train: x_test, y_true: y_test}))
        print("在測試集上模型準確率爲:%f" % (count / 100))
    return None


if __name__ == "__main__":
    # file_name_list = os.listdir(r"E:\GameDownload\dataset_mnist")
    # # print(file_name_list)
    # file_list = [os.path.join(r"E:\GameDownload\dataset_mnist", file_name)
    #               for file_name in file_name_list if file_name[-4:] == "byte"]
    # # print(file_queue)
    mnist_recognition()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章