python 簡單使用MNIST數據集實現手寫數字識別

一、瞭解MNIST數據集

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np

#0 讀取mnist數據集
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#1 讀取訓練集的一張圖片與標籤
batch_xs, batch_ys = mnist.train.next_batch(1)
img = tf.reshape(batch_xs,[28,28])
label = tf.argmax(batch_ys,1)

#2 讀取測試集的一張圖片與標籤
img_tt = mnist.test.images[0]
img_t = tf.reshape(img_tt,[28,28])
label_tt = mnist.test.labels[0]
label_tt = np.reshape(label_tt,[1,10])
label_t = tf.argmax(label_tt,1)

with tf.Session() as sess:
    # 訓練集 打印標籤 像素值 
    print("訓練集:%d"%sess.run(label))
    im = sess.run(img) 
    for i in range(28):
        for j in range(28):
            print("%d "%round(im[i][j]),end='')
        print()    
    
    #測試集 打印標籤 像素值 
    print("測試集:%d"%sess.run(label_t))
    im_t = sess.run(img_t) 
    for i in range(28):
        for j in range(28):
            print("%d "%round(im_t[i][j]),end='')
        print()
        
    #顯示圖片
    plt.figure()
    plt.subplot(121)
    plt.imshow(im,cmap = 'gray')
    plt.subplot(122)
    plt.imshow(im_t,cmap = 'gray')    
    plt.show()
    

輸出結果如下: 

二、 簡單的訓練、測試和保存模型

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

#1 讀取訓練數據
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#2 建立模型
x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y=tf.nn.softmax(tf.matmul(x,w)+b)

#3 損失函數   交叉熵
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

#4 優化訓練
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

#5 開啓會話
with tf.Session() as sess:
    # 初始化
    sess.run(tf.global_variables_initializer())

    # 訓練
    for i in range(500):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(optimizer, feed_dict={x: batch_xs, y_: batch_ys})

    # 測試模型
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) * 100
    print(accuracy)

    # 保存模型
    model_path = "model/"
    model_name = "model_" + (str(accuracy))[:4] + "%"
    tf.train.Saver().save(sess,model_path+model_name)

三、簡單的使用模型

import tensorflow as tf

#1 模型參數
x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y=tf.nn.softmax(tf.matmul(x,w)+b)
#y_ = tf.placeholder("float", [None,10])

#2 讀取圖片
img = tf.read_file('7.jpg')                                 #讀取 彩色圖片
im_3 = tf.image.decode_jpeg(img, channels=3)                #解碼 
im_resize = tf.image.resize_images(im_3,[28,28])            #縮放成28X28
im_gry = tf.squeeze(tf.image.rgb_to_grayscale(im_resize),2) #灰度化  降維變成二維
im_reshape = tf.reshape(im_gry,[1,784])                     #改變形狀 

#3 開啓會話
with tf.Session() as sess:
    # 讀取模型
    tf.train.Saver().restore(sess,"model/model_91.4%")
    
    # 輸入待檢測圖像數據
    xx = sess.run(im_reshape)

    # 進行識別並輸出
    result = sess.run(tf.argmax(y, 1), feed_dict={x: xx})
    print(result)
    

 結果:                   結果:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章