mnist手写体识别

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# import os


def mnist_recognition():
    """
    使用全连接进行手写体识别
    :return:
    """
    # 1、准备数据
    #    两种数据读取方式:
    #   (1)、QueueRunner
    #   (2)、Feeding
    mnist = input_data.read_data_sets(r"E:\GameDownload\dataset_mnist", one_hot=True)
    x_train = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    # 2、构建全连接模型(注意模型参数应用变量存储)
    Weights = tf.Variable(initial_value=tf.random.normal(shape=[784, 10]))
    bias = tf.Variable(initial_value=tf.random.normal(shape=[10]))
    y_predict = tf.matmul(x_train, Weights) + bias
    # print(y_predict)
    # 3、构造损失函数(用softmax表示的交叉熵)
    error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
    # 4、优化损失(使用梯度下降方法)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(error)
    # 5、计算准确率,对y_predict使用argmax可以找出其一行中最大值所在的列
    #    由于使用的是one-hot编码,所以预测值与真实值在编码内的位置相同时为true,否则为false
    #    之后将bool值转为浮点数后求均值,即为一个batch内true的机率
    equal_list = tf.equal(tf.argmax(y_predict, 1), tf.argmax(y_true, 1))
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    init = tf.global_variables_initializer()
    with tf.compat.v1.Session() as sess:
        sess.run(init)
        image, label = mnist.train.next_batch(100)
        # print(x_train)

        for i in range(1000):
            loss, _, y_predict_val, accuracy_val = sess.run([error, optimizer, y_predict, accuracy],
                                                            feed_dict={y_true: label, x_train: image})
            # print("y_predict:\n", sess.run(y_predict, feed_dict={y_true: label, x_train: image}))
            # print("第%d次迭代后:损失为:%f, 准确率为%f" % (i + 1, loss, accuracy_val))

        # 6、得到模型之后在测试集中进行验证
        count = 0.0
        for i in range(100):
            x_test, y_test = mnist.test.next_batch(1)
            test_predict = tf.argmax(sess.run(y_predict, feed_dict={x_train: x_test, y_true: y_test}), 1).eval()
            test_true = tf.argmax(y_test, 1).eval()
            if test_true == test_predict:
                count += 1
            print("第%d次测试的预测值为:%d, 真实值为:%d" % (i+1, test_predict, test_true))
            # print(test_true)
        print("在测试集上模型准确率为:%f" % (count / 100))
    return None


if __name__ == "__main__":
    # file_name_list = os.listdir(r"E:\GameDownload\dataset_mnist")
    # # print(file_name_list)
    # file_list = [os.path.join(r"E:\GameDownload\dataset_mnist", file_name)
    #               for file_name in file_name_list if file_name[-4:] == "byte"]
    # # print(file_queue)
    mnist_recognition()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章