機器學習入門都是從MNIST開始,Tensorflow官方社區提供了十分詳細的教程【MNIST機器學習入門】。但是我們顯然不滿足於僅僅把官方的代碼複製一遍然後輸出個結果,我們想能不能實現自己手寫數字的識別。
本文作爲Tensorflow入門,結合官方代碼,利用Softmax迴歸函數,實現模型的訓練、保存、以及重新加載,完成對自己手寫數字的識別。
1.模型訓練及保存
模型我們採用Softmax迴歸函數,具體代碼參考【MNIST機器學習入門】,這裏用梯度下降算法以0.01學習率最小化交叉熵對模型進行1000次訓練。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 插入數據
# name在保存模型時非常有用
x = tf.placeholder("float", [None, 784], name='x')
W = tf.Variable(tf.zeros([784, 10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b, name='y') # y預測概率分佈
y_ = tf.placeholder("float", [None, 10]) # y_實際概率分佈
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 交叉熵
# 梯度下降算法以0.01學習率最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables() # 初始化變量
sess = tf.Session()
sess.run(init)
saver = tf.train.Saver()
for i in range(1000): # 開始訓練模型,循環1000次
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, 'minst_model.ckpt') # 保存模型
在代碼最前面,定義張量(變量)時,我們給每個張量(變量)都加了name關鍵字,這個對我們後期再次加載模型很重要。
在最後,我們利用saver.save()函數,保存模型。模型名稱爲minst_model.ckpt。之後我們可以在文件夾下看到4個文件:
- checkpoint: 保存目錄下所有模型文件列表
- minst_model.ckpt.meta :保存了計算圖的結構,可以理解爲模型的結構
- minst_model.ckpt.index 和 minst_model.ckpt.data-00000-of-00001:保存了模型中所有變量的值.
2.模型加載
保存好模型之後,我們利用自己的圖片對模型進行測試。我們利用windows自帶的畫圖軟件,進行數字手寫,並保存成28*28像素的png圖片。例如0,1,2手寫體圖片,如下圖所示:
整個測試代碼如下:
from PIL import Image, ImageFilter
import tensorflow as tf
def imageprepare():
file_name = 'pic/2-3.png' # 圖片路徑
myimage = Image.open(file_name).convert('L') # 轉換成灰度圖
tv = list(myimage.getdata()) # 獲取像素值
# 轉換像素範圍到[0 1], 0是純白 1是純黑
tva = [(255-x)*1.0/255.0 for x in tv]
return tva
result = imageprepare()
init = tf.global_variables_initializer()
saver = tf.train.Saver
with tf.Session() as sess:
sess.run(init)
saver = tf.train.import_meta_graph('minst_model.ckpt.meta') # 載入模型結構
saver.restore(sess, 'minst_model.ckpt') # 載入模型參數
graph = tf.get_default_graph() # 計算圖
x = graph.get_tensor_by_name("x:0") # 從模型中獲取張量x
y = graph.get_tensor_by_name("y:0") # 從模型中獲取張量y
prediction = tf.argmax(y, 1)
predint = prediction.eval(feed_dict={x: [result]}, session=sess)
print(predint[0])
在加載模型時,我們先用tf.train.import_meta_graph()載入模型的結構,之後利用saver.restore()加載模型的訓練好的參數。graph.get_tensor_by_name()依照名字(name)從模型中獲取張量。所以前面在保存模型時我們給每個張量和變量都加了name關鍵字。
關於如何保存和加載訓練模型可以參見博客【TensorFlow保存還原模型的正確方式】
3.識別結果
輸出的識別結果如下所示:
經測試,該方法基本識別率可以達到90%左右。所以基本可以滿足要求。
4.注意事項
最早時,我手寫數字進行識別時,發現準確率很低。
後來發現原因是:(1)我自己手動畫的數字線條太細了;(2)畫的有些數字在圖片中的位置沒有位於中心;(3)訓練集是西方的手寫數字,和中國的手寫數字習慣不同。下面是官方的訓練數據中的部分數字。
在畫圖時,數字效果(畫筆粗細等)儘量和上面訓練集保持一致,就會得到較高的識別率!
是以爲記!