Tensorflow實現MNIST手寫識別

MNIST手寫體識別訓練和測試模型下載地址:
MNIST手寫體模型下載

MNIST手寫體識別,標籤編碼爲獨熱(one-hot)編碼

One-Hot編碼,又稱爲一位有效編碼,主要是採用N位狀態寄存器來對N個狀態進行編碼,每個狀態都由他獨立的寄存器位,並且在任意時候只有一位有效。
One-Hot編碼是分類變量作爲二進制向量的表示。這首先要求將分類值映射到整數值。然後,每個整數值被表示爲二進制向量,除了整數的索引之外,它都是零值,它被標記爲1。

導入相關包

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np

numpy安裝:
pip install numpy
matplotlib安裝:
pip install matplotlib

MNIST圖像讀取

	mnist = input_data.read_data_sets("data/MNIST/", one_hot=True)
	# mnist 中每張圖片共有28*28=784個像素點

變量定義

	x = tf.placeholder(tf.float32, [None, 784], name='x')
    # 0-9 一共十個數字-》十個類別
    y = tf.placeholder(tf.float32, [None, 10], name='y')
    # 定義變量
    w = tf.Variable(tf.zeros([784.10]), name='w')
    b = tf.Variable(tf.zeros([10]), name='b')
    # 使用單個神經元,進行前向計算
    forward = tf.matmul(x, w) + b
    # 使用softmax對結果集進行分類
    pred = tf.nn.softmax(forward)

    # 訓練次數
    train_epochs = 50
    # 單次訓練樣本數(批次大小)
    batch_size = 10
    # 一輪訓練有多少批次
    total_batch = int(mnist.train.num_examples / batch_size)
    learning_rate = 0.01
    # 顯示粒度
    display_step = 1

定義損失函數和優化器

	# 定義交叉熵損失函數
    loss_function = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))

    # 定義優化器,梯度下降
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

定義準確率

    # 檢查預測類別tf.argmax(pred,1) 與實際類別tf.argmax(y,1)的匹配情況
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

    # 準確率,將布爾值轉化爲浮點數,並計算平均值
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

定義Tensorflow會話

	sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

模型訓練

    for epoch in range(train_epochs):
        for batch in range(total_batch):
            # 讀取批次數據
            xs, ys = mnist.train.next_batch(batch_size)
            # 執行批次訓練
            sess.run(optimizer, feed_dict={x: xs, y: ys})
        # total_batch個批次訓練完成後,使用驗證數據計算誤差與準確率,驗證集未分批
        loss, acc = sess.run([loss_function, accuracy],
                             feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
        # 打印訓練過程中的詳細信息
        if (epoch + 1) % display_step == 0:
            print("Train Epoch:", '%02d' % (epoch + 1), 'Loss=', '{:.9f}'.format(loss), 'Accuracy=',
                  '{:.4f}'.format(acc))

圖像可視化函數

def plot_images_labels_prediction(images,  # 圖像列表
                                  labels,  # 標籤列表
                                  prediction,  # 預測值列表
                                  index,  # 從第index個開始顯示
                                  num=10):  # 缺省一次顯示10幅
    fig = plt.gcf()  # 獲取當前圖標,Get Current Figure
    fig.set_size_inches(10, 12)  # 1英寸等於2.54cm
    if num > 25:
        num = 25  # 最多顯示25個子圖
    for i in range(0, num):
        ax = plt.subplot(5, 5, i + 1)  # 獲取當前要處理的子圖
        # 顯示第index個圖像
        ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')

        # 構建該圖上要顯示的title
        title = "label=" + str(np.argmax(labels[index]))
        if len(prediction) > 0:
            title += ",predict=" + str(prediction[index])

        # 顯示圖上的title信息
        ax.set_title(title, fontsize=10)
        # 不限是座標軸
        ax.set_xticks([])
        ax.set_yticks([])
        index += 1

    plt.show()

該過程代碼基於Tensorflow 1.0完成,Tensorflow 1.0安裝:

  1. 通過Anaconda完成安裝:
	# 創建名稱爲tf-1.0的conda虛擬Python環境,並指定Python版本爲3.5
	conda create -n tf-1.0 python=3.5
	# 激活tf-1.0環境
	conda activate tf-1.0
	# 查找tensorflow版本號
	conda search tensorflow
	# 安裝指定版本的tensorflow
	conda install tensorflow=1.9
  1. 通過pip安裝:
	# 安裝指定版本的tensorflow,默認安裝tensorflow - 2.0
	pip install tensorflow==1.9
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章