MNIST手寫體識別訓練和測試模型下載地址:
MNIST手寫體模型下載
MNIST手寫體識別,標籤編碼爲獨熱(one-hot)編碼
One-Hot編碼,又稱爲一位有效編碼,主要是採用N位狀態寄存器來對N個狀態進行編碼,每個狀態都由他獨立的寄存器位,並且在任意時候只有一位有效。
One-Hot編碼是分類變量作爲二進制向量的表示。這首先要求將分類值映射到整數值。然後,每個整數值被表示爲二進制向量,除了整數的索引之外,它都是零值,它被標記爲1。
導入相關包
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np
numpy安裝:
pip install numpy
matplotlib安裝:
pip install matplotlib
MNIST圖像讀取
mnist = input_data.read_data_sets("data/MNIST/", one_hot=True)
# mnist 中每張圖片共有28*28=784個像素點
變量定義
x = tf.placeholder(tf.float32, [None, 784], name='x')
# 0-9 一共十個數字-》十個類別
y = tf.placeholder(tf.float32, [None, 10], name='y')
# 定義變量
w = tf.Variable(tf.zeros([784.10]), name='w')
b = tf.Variable(tf.zeros([10]), name='b')
# 使用單個神經元,進行前向計算
forward = tf.matmul(x, w) + b
# 使用softmax對結果集進行分類
pred = tf.nn.softmax(forward)
# 訓練次數
train_epochs = 50
# 單次訓練樣本數(批次大小)
batch_size = 10
# 一輪訓練有多少批次
total_batch = int(mnist.train.num_examples / batch_size)
learning_rate = 0.01
# 顯示粒度
display_step = 1
定義損失函數和優化器
# 定義交叉熵損失函數
loss_function = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
# 定義優化器,梯度下降
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
定義準確率
# 檢查預測類別tf.argmax(pred,1) 與實際類別tf.argmax(y,1)的匹配情況
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# 準確率,將布爾值轉化爲浮點數,並計算平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
定義Tensorflow會話
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
模型訓練
for epoch in range(train_epochs):
for batch in range(total_batch):
# 讀取批次數據
xs, ys = mnist.train.next_batch(batch_size)
# 執行批次訓練
sess.run(optimizer, feed_dict={x: xs, y: ys})
# total_batch個批次訓練完成後,使用驗證數據計算誤差與準確率,驗證集未分批
loss, acc = sess.run([loss_function, accuracy],
feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
# 打印訓練過程中的詳細信息
if (epoch + 1) % display_step == 0:
print("Train Epoch:", '%02d' % (epoch + 1), 'Loss=', '{:.9f}'.format(loss), 'Accuracy=',
'{:.4f}'.format(acc))
圖像可視化函數
def plot_images_labels_prediction(images, # 圖像列表
labels, # 標籤列表
prediction, # 預測值列表
index, # 從第index個開始顯示
num=10): # 缺省一次顯示10幅
fig = plt.gcf() # 獲取當前圖標,Get Current Figure
fig.set_size_inches(10, 12) # 1英寸等於2.54cm
if num > 25:
num = 25 # 最多顯示25個子圖
for i in range(0, num):
ax = plt.subplot(5, 5, i + 1) # 獲取當前要處理的子圖
# 顯示第index個圖像
ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')
# 構建該圖上要顯示的title
title = "label=" + str(np.argmax(labels[index]))
if len(prediction) > 0:
title += ",predict=" + str(prediction[index])
# 顯示圖上的title信息
ax.set_title(title, fontsize=10)
# 不限是座標軸
ax.set_xticks([])
ax.set_yticks([])
index += 1
plt.show()
該過程代碼基於Tensorflow 1.0完成,Tensorflow 1.0安裝:
- 通過Anaconda完成安裝:
# 創建名稱爲tf-1.0的conda虛擬Python環境,並指定Python版本爲3.5 conda create -n tf-1.0 python=3.5 # 激活tf-1.0環境 conda activate tf-1.0 # 查找tensorflow版本號 conda search tensorflow # 安裝指定版本的tensorflow conda install tensorflow=1.9
- 通過pip安裝:
# 安裝指定版本的tensorflow,默認安裝tensorflow - 2.0 pip install tensorflow==1.9