Tensorflow(1):MNIST識別自己手寫的數字--入門篇(Softmax迴歸)

版權聲明:本文爲博主原創文章,未經博主允許不得轉載。 https://blog.csdn.net/u011389706/article/details/81223784

  機器學習入門都是從MNIST開始,Tensorflow官方社區提供了十分詳細的教程【MNIST機器學習入門】。但是我們顯然不滿足於僅僅把官方的代碼複製一遍然後輸出個結果,我們想能不能實現自己手寫數字的識別。
  本文作爲Tensorflow入門,結合官方代碼,利用Softmax迴歸函數,實現模型的訓練、保存、以及重新加載,完成對自己手寫數字的識別。

1.模型訓練及保存

  模型我們採用Softmax迴歸函數,具體代碼參考【MNIST機器學習入門】,這裏用梯度下降算法以0.01學習率最小化交叉熵對模型進行1000次訓練。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # 插入數據

# name在保存模型時非常有用
x = tf.placeholder("float", [None, 784], name='x') 
W = tf.Variable(tf.zeros([784, 10]), name='W')  
b = tf.Variable(tf.zeros([10]), name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b, name='y')  # y預測概率分佈

y_ = tf.placeholder("float", [None, 10])  # y_實際概率分佈

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))  # 交叉熵
# 梯度下降算法以0.01學習率最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  
init = tf.initialize_all_variables()  # 初始化變量

sess = tf.Session()
sess.run(init)
saver = tf.train.Saver()

for i in range(1000):  # 開始訓練模型,循環1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, 'minst_model.ckpt')  # 保存模型

  在代碼最前面,定義張量(變量)時,我們給每個張量(變量)都加了name關鍵字,這個對我們後期再次加載模型很重要。
  在最後,我們利用saver.save()函數,保存模型。模型名稱爲minst_model.ckpt。之後我們可以在文件夾下看到4個文件:

  • checkpoint: 保存目錄下所有模型文件列表
  • minst_model.ckpt.meta :保存了計算圖的結構,可以理解爲模型的結構
  • minst_model.ckpt.index 和 minst_model.ckpt.data-00000-of-00001:保存了模型中所有變量的值.

2.模型加載

   保存好模型之後,我們利用自己的圖片對模型進行測試。我們利用windows自帶的畫圖軟件,進行數字手寫,並保存成28*28像素的png圖片。例如0,1,2手寫體圖片,如下圖所示:


     

  整個測試代碼如下:

from PIL import Image, ImageFilter
import tensorflow as tf

def imageprepare():
    file_name = 'pic/2-3.png'  # 圖片路徑
    myimage = Image.open(file_name).convert('L')  # 轉換成灰度圖
    tv = list(myimage.getdata())  # 獲取像素值
    # 轉換像素範圍到[0 1], 0是純白 1是純黑
    tva = [(255-x)*1.0/255.0 for x in tv] 
    return tva

result = imageprepare()
init = tf.global_variables_initializer()
saver = tf.train.Saver

with tf.Session() as sess:
    sess.run(init)
    saver = tf.train.import_meta_graph('minst_model.ckpt.meta')  # 載入模型結構
    saver.restore(sess,  'minst_model.ckpt')  # 載入模型參數

    graph = tf.get_default_graph()  # 計算圖
    x = graph.get_tensor_by_name("x:0")  # 從模型中獲取張量x
    y = graph.get_tensor_by_name("y:0")  # 從模型中獲取張量y

    prediction = tf.argmax(y, 1)
    predint = prediction.eval(feed_dict={x: [result]}, session=sess)
    print(predint[0])

  在加載模型時,我們先用tf.train.import_meta_graph()載入模型的結構,之後利用saver.restore()加載模型的訓練好的參數。graph.get_tensor_by_name()依照名字(name)從模型中獲取張量。所以前面在保存模型時我們給每個張量和變量都加了name關鍵字
  關於如何保存和加載訓練模型可以參見博客【TensorFlow保存還原模型的正確方式】

3.識別結果

  輸出的識別結果如下所示:


    

     

    

  經測試,該方法基本識別率可以達到90%左右。所以基本可以滿足要求。

4.注意事項

  最早時,我手寫數字進行識別時,發現準確率很低。
  後來發現原因是:(1)我自己手動畫的數字線條太細了;(2)畫的有些數字在圖片中的位置沒有位於中心;(3)訓練集是西方的手寫數字,和中國的手寫數字習慣不同。下面是官方的訓練數據中的部分數字。


  在畫圖時,數字效果(畫筆粗細等)儘量和上面訓練集保持一致,就會得到較高的識別率!
  是以爲記!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章