【tensorflow】識別圖中模糊手寫數字

1、導入NMIST數據集。

手動下載:http://yann.lecun.com/exdb/mnist/

自動下載:

from tensorflow.examples.tutorials.mnist import input_data
minst=input_data.read_data_sets("MNIST_data/",one_hot)

 

2、分析MNIST樣本特點定義變量。

由於輸入圖片是個550000*784的矩陣,所以先創建一個[None,784]的佔位符x和[None,10]的佔位符y,然後使用feed機制將圖片和標籤輸入進去。

import tensorflow as tf
from tensorflow.examples.tutorials.minst import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)
import pylab

tf.reset_default_graph()
#定義佔位符
x=tf.placeholder(tf.float32,[None,784])# mnist data維度 28*28=784
y=tf.placeholder(tf.flloat32,[None,10])## 0-9 數字=> 10 classes

3、構建模型。

①定義學習參數:使用Variable定義學習參數。

w=tf.Variable(tf.random_normal([784,10]))
b=tf.Variable(tf.zeros([10]))

 

②定義輸出節點

pred=tf.nn.softmax(tf.matul(x,w)+b)

這裏的x是一個二維張量,擁有多個輸入。然後在加上b,把它們的和輸入到tf.nn.softmax函數裏。

softmax詳解參考:https://blog.csdn.net/bitcarmanlee/article/details/82320853?depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-1&utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-1

③定義反向傳播的結構

#損失函數
cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
#定義參數
learning_rate=0.01
#使用梯度下降優化器
optimizer=tf.train.GradientDescentOptimzer(learning_rate).minimize(cost)

整個過程就是不斷讓損失值cost變小,因爲損失值越小,才能表明輸出的結果跟標籤的數據越相近。當cost小到我們的需求時,這時的b和w就是訓練出來的合適值。

 

4、訓練模型並輸出中間狀態參數。

training_epochs=25 #整個訓練迭代25次
batch_size=100 #訓練過程中一次取100條數據進行訓練
display_step=1#每訓練一次就把具體的中間狀態顯示出來

#啓動session
with tf.Session() as sess:
    see.run(tf.global_variables_initializer())

    #啓動循環開始訓練
    for epoch in rande(traning_epochs):
        avg_cost=0
        total_batch = int(mnist.train.num_examples/batch_size)
        # 遍歷全部數據集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # 運行優化器
            c = sess.run([optimizer, cost], feed_dict={x: batch_xs,y: batch_ys})
            # 計算loss值
            avg_cost += c / total_batch
        # 顯示訓練中的詳細信息
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print( " Finished!")

5、測試模型

測試錯誤率的算法:直接判斷預測結果與真實標籤是否相同,如果相同,就表明是正確的,如果不相同,就表示是錯誤的。然後正確的個數除以總個數,得到的即爲正確率。

correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y, 1)))
    #計算準確率
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)))
    print("Accuracy:",accuracy.eval(({x: mnist.test.images, y: mnist.test.labels}))))

6、保存模型

#保存模型
save_path=saver.save(sess,model_path)
print("Model saved in file:%s"%save_path)

7、讀取模型

with tf.Session() as sess:
    #初始化變量
    sess.run(tf.global_variables_initializer())
    #恢復模型變量
    saver.restre(sess,model_path)

    # 測試 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 計算準確率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
    
    output=tf.argmax(pred,1)
    batch_xs,batch_ys=mnist.train.next_batch(2)
    outputval,predv=see.run([output,pred],feed_dict={x: batch_xs})
    print(outputval,predv,batch_ys)

    im=batch_xs[0]
    im=im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show()

完整代碼

import tensorflow as tf #導入tensorflow庫
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab 

tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist data維度 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 數字=> 10 classes

# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# 構建模型
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分類

# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

#參數設置
learning_rate = 0.01
# 使用梯度下降優化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 25
batch_size = 100
display_step = 1
saver = tf.train.Saver()
model_path = "log/521model.ckpt"

# 啓動session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())# Initializing OP

    # 啓動循環開始訓練
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # 遍歷全部數據集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # 顯示訓練中的詳細信息
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print( " Finished!")

    # 測試 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 計算準確率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save model weights to disk
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)



#讀取模型
print("Starting 2nd session...")
with tf.Session() as sess:
    # Initialize variables
    sess.run(tf.global_variables_initializer())
    # Restore model weights from previously saved model
    saver.restore(sess, model_path)
    
     # 測試 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 計算準確率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
    
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
    print(outputval,predv,batch_ys)

    im = batch_xs[0]
    im = im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show()
    
    im = batch_xs[1]
    im = im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show()

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章