Single-Layer Neuron Image Recognition on the MNIST Dataset

Recognizing blurry handwritten digits with MNIST

I. Importing the MNIST dataset

An introduction to the MNIST dataset (including a cloud-drive copy of the dataset): https://blog.csdn.net/RObot_123/article/details/103220099

Manual download (official site): http://yann.lecun.com/exdb/mnist/

1. Downloading MNIST with TensorFlow

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

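The code above automatically downloads the MNIST dataset into an "MNIST_data" folder in the working directory. The tensorflow.examples.tutorials.mnist module is only available in TensorFlow 1.x, and all code in this article uses the TensorFlow 1.x API.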
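With one_hot=True, every label is stored as a 10-element vector containing a single 1 at the index of its digit, which is why the label placeholder defined later has shape [None, 10]. A minimal pure-Python sketch of that encoding (one_hot is a hypothetical helper for illustration, not part of TensorFlow):

```python
def one_hot(label, num_classes=10):
    """Return a list of length num_classes with 1.0 at position `label` (hypothetical helper)."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```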

2. Inspecting the dataset

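print('Training images:', mnist.train.images)
print('Training images shape:', mnist.train.images.shape)

import pylab  # matplotlib's MATLAB-style plotting interface
im = mnist.train.images[1]  # training image at index 1, a flat vector of 784 pixels
im = im.reshape(-1, 28)     # reshape to 28x28 for display
pylab.imshow(im)
pylab.show()

print('Test images shape:', mnist.test.images.shape)
print('Validation images shape:', mnist.validation.images.shape)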

The output is as follows:
[Figure: output of the code above]

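| No. | Content |
|---|---|
| 1 | Extracting the dataset files |
| 2 | Printing the extracted image data |
| 3 | Printing the shape of the training images |
| 4 | Displaying the training image at index 1 |
| 5 | Printing the shapes of the test and validation sets |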

For an introduction to shape: https://blog.csdn.net/RObot_123/article/details/103102627
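In the display code above, im.reshape(-1, 28) turns the flat 784-pixel vector back into a 28×28 image; the -1 asks NumPy to infer that dimension (784 / 28 = 28 rows). The same row-major split, sketched in plain Python (reshape_rows is a hypothetical helper, and the pixel values are stand-ins):

```python
def reshape_rows(flat, row_len):
    """Split a flat list into rows of `row_len`, like NumPy's reshape(-1, row_len)."""
    if len(flat) % row_len != 0:
        raise ValueError("length must be a multiple of row_len")
    return [flat[i:i + row_len] for i in range(0, len(flat), row_len)]

pixels = list(range(784))          # stand-in for one flattened MNIST image
image = reshape_rows(pixels, 28)
print(len(image), len(image[0]))   # 28 28
print(image[1][0])                 # 28: row-major order, so the second row starts at pixel 28
```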

II. Analyzing the MNIST samples and defining variables

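The training images form a 55000×784 matrix (55000 images, each flattened from 28×28 to 784 pixels),
so we create **a [None,784] placeholder x and a [None,10] placeholder y**, where None lets the batch size vary.
The images and labels are then fed in through the feed mechanism (feed_dict).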

import tensorflow as tf  # import the TensorFlow library
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab 

tf.reset_default_graph()
# Define placeholders
x = tf.placeholder(tf.float32, [None, 784]) # MNIST image dimension: 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # digits 0-9 => 10 classes

III. Building the model

1. Defining the learning parameters

  • Define the weight variable W
  • Define the bias variable b
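# Define the learning parameters
W = tf.Variable(tf.random_normal([784, 10]))  # weights, drawn from a standard normal distribution
b = tf.Variable(tf.zeros([10]))  # biases, initialized to zero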
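Together with the placeholders, these shapes fit as follows: a batch x of shape [n, 784] multiplied by W of shape [784, 10] gives [n, 10], and b of shape [10] is added to every row, matching the [None, 10] labels. The sketch below reproduces that shape logic in plain Python; matmul and linear are hypothetical helpers, random.gauss(0.0, 1.0) stands in for tf.random_normal (mean 0, standard deviation 1 by default), and the all-zero batch is made-up input.

```python
import random

def matmul(a, b):
    """Multiply an n x k matrix by a k x m matrix (both given as lists of rows)."""
    k = len(b)
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions do not match")
    m = len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) for j in range(m)] for row in a]

def linear(x, W, b):
    """Compute x . W + b, adding the bias vector to every row (every batch item)."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(x, W)]

random.seed(0)
W = [[random.gauss(0.0, 1.0) for _ in range(10)] for _ in range(784)]  # like tf.random_normal([784, 10])
b = [0.0] * 10                                                          # like tf.zeros([10])
batch = [[0.0] * 784 for _ in range(3)]   # a hypothetical batch of 3 blank images
logits = linear(batch, W, b)
print(len(logits), len(logits[0]))        # 3 10
```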

2. Defining the output node

  • Softmax classification
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax classification
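tf.nn.softmax turns each row of logits into probabilities that are non-negative and sum to 1, with the largest logit receiving the largest probability. A short sketch of the formula for one row, using the common trick of subtracting the maximum logit first to avoid overflow (a standard way to compute softmax, not necessarily how TensorFlow implements it internally):

```python
import math

def softmax(logits):
    """Softmax of one row; subtracting the max does not change the result but avoids overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```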

3. Defining the backpropagation structure

  • Loss function: cross-entropy
  • Learning rate: 0.01
  • Optimizer: GradientDescentOptimizer (gradient descent)
# Loss function: cross-entropy, averaged over the batch
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

# Parameter settings
learning_rate = 0.01
# Use the gradient descent optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
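TensorFlow's minimize() derives the gradients automatically. To show what one gradient-descent step does for this softmax-regression model, the sketch below computes them by hand, using the standard result that for softmax followed by cross-entropy, the gradient of the batch-mean loss with respect to each logit is (pred - y) / batch_size. The toy data (2 features, 3 classes, 2 examples) and all function names are assumptions for illustration; the real model has 784 features and 10 classes.

```python
import math

def softmax(z):
    """Numerically stable softmax of one row of logits."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def forward(X, W, b):
    """pred = softmax(X . W + b), one probability row per example."""
    return [softmax([sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
                     for j in range(len(b))]) for x in X]

def cross_entropy(Y, P):
    """Batch mean of -sum(y * log(pred)), mirroring the cost expression above.
    Terms with y == 0 are skipped, so a predicted probability of exactly 0
    for a non-target class never produces 0 * log(0)."""
    losses = [-sum(yj * math.log(pj) for yj, pj in zip(y, p) if yj > 0)
              for y, p in zip(Y, P)]
    return sum(losses) / len(losses)

def gd_step(X, Y, W, b, lr):
    """One plain gradient-descent step; updates W and b in place."""
    n = len(X)
    P = forward(X, W, b)                  # predictions from the current parameters
    for j in range(len(b)):
        g = [(p[j] - y[j]) / n for p, y in zip(P, Y)]   # d(cost)/d(logit j) per example
        b[j] -= lr * sum(g)
        for i in range(len(W)):
            W[i][j] -= lr * sum(gk * x[i] for gk, x in zip(g, X))

# Toy problem (hypothetical data): 2 features, 3 classes, 2 examples.
X = [[1.0, 0.0], [0.0, 1.0]]
Y = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
W = [[0.0] * 3 for _ in range(2)]
b = [0.0] * 3

before = cross_entropy(Y, forward(X, W, b))   # uniform predictions give log(3)
for _ in range(100):
    gd_step(X, Y, W, b, lr=0.5)
after = cross_entropy(Y, forward(X, W, b))
print(round(before, 4), after < before)       # 1.0986 True
```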

IV. Training the model and printing intermediate results

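  • Number of training epochs: 25 (each epoch is one full pass over the training set)
  • Batch size: 100
  • Display step: 1 (print the average cost after every epoch)
  • Run the computation in a Session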
training_epochs = 25
batch_size = 100
display_step = 1
#saver = tf.train.Saver()
#model_path = "log/521model.ckpt"

# Start the session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())# Initializing OP

    # Loop over the epochs to train
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # Iterate over the whole training set
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # Print training progress
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print( " Finished!")

Output:
[Figure: training log printed by the code above]
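The training loop runs total_batch = int(55000 / 100) = 550 batches per epoch, and adding c / total_batch for each batch leaves avg_cost equal to the mean batch cost of that epoch. A pure-Python sketch of that bookkeeping; iterate_minibatches is a hypothetical helper that reshuffles the indices once per epoch (any remainder is simply dropped here), and the cost values are random stand-ins rather than real losses:

```python
import random

def iterate_minibatches(num_examples, batch_size, rng):
    """Yield shuffled index batches for one epoch (hypothetical helper; a remainder is dropped)."""
    order = list(range(num_examples))
    rng.shuffle(order)
    total_batch = num_examples // batch_size   # same value as int(num_examples / batch_size)
    for i in range(total_batch):
        yield order[i * batch_size:(i + 1) * batch_size]

num_examples, batch_size = 55000, 100        # mnist.train.num_examples and batch_size in the article
total_batch = int(num_examples / batch_size)
print(total_batch)                            # 550

rng = random.Random(0)
batches = list(iterate_minibatches(num_examples, batch_size, rng))
print(len(batches), len(batches[0]))          # 550 100

# Accumulating c / total_batch over every batch yields the mean batch cost.
batch_costs = [rng.random() for _ in range(total_batch)]   # stand-in values, not real losses
avg_cost = 0.0
for c in batch_costs:
    avg_cost += c / total_batch
print(abs(avg_cost - sum(batch_costs) / total_batch) < 1e-12)
```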

V. Testing the model

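  • Compare the output (pred) with the labels (y)
  • Use reduce_mean to average correct_prediction (cast to float32), which gives the accuracy
This code goes inside the with tf.Session() block above, after the training loop.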
    # Test the model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Compute the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

Model accuracy:
[Figure: printed test accuracy]
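The test step compares the index of the largest predicted probability (tf.argmax(pred, 1)) with the index of the 1 in each one-hot label, then averages the 0/1 matches. A small pure-Python sketch with made-up predictions and labels (argmax and accuracy are hypothetical helpers):

```python
def argmax(values):
    """Index of the largest value (the first one on ties, in this sketch)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(pred_batch, label_batch):
    """Fraction of rows whose predicted class matches the one-hot label's class."""
    correct = [argmax(p) == argmax(y) for p, y in zip(pred_batch, label_batch)]
    return sum(correct) / len(correct)  # booleans count as 0/1, like tf.cast then reduce_mean

pred = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6], [0.5, 0.4, 0.1]]
labels = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0]]
print(accuracy(pred, labels))  # 0.75
```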

VI. Saving the model

  • Create the saver and the model path
  • Save the model
saver = tf.train.Saver()  # create after W and b are defined; it saves the variables that exist at this point
model_path = "log/mnist_model.ckpt"

Call the saver inside the session, after training:

    # Save model weights to disk
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)

Output:
[Figure: "Model saved in file" message]
Files actually saved:
[Figure: checkpoint files in the log folder]

VII. Loading the model

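To restore the model, first comment out the training session above, then add the following code after the graph definition. It opens a second session and restores the saved weights into it.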

# Load the model
print("Starting 2nd session...")
with tf.Session() as sess:
    # Initialize variables
    sess.run(tf.global_variables_initializer())
    # Restore model weights from previously saved model
    saver.restore(sess, model_path)
    
    # Test the model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Compute the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
    
    output = tf.argmax(pred, 1)  # predicted digit for each input image
    batch_xs, batch_ys = mnist.train.next_batch(2)  # take 2 training samples to inspect
    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
    print(outputval,predv,batch_ys)

    im = batch_xs[0]
    im = im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show()
    
    im = batch_xs[1]
    im = im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show() 

[Figure: restored-model accuracy, predictions for the two samples, and the two displayed digits]

VIII. Complete code

1. Inspecting the dataset (brief)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

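print('Training images:', mnist.train.images)
print('Training images shape:', mnist.train.images.shape)

import pylab  # matplotlib's MATLAB-style plotting interface
im = mnist.train.images[1]  # training image at index 1
im = im.reshape(-1, 28)     # reshape to 28x28 for display
pylab.imshow(im)
pylab.show()
print('Test images shape:', mnist.test.images.shape)
print('Validation images shape:', mnist.validation.images.shape)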

2. Inspecting the dataset, version 2 (more detailed)

# Import the MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data  # module that downloads MNIST
mnist = input_data.read_data_sets('MNIST_data/', one_hot=False)  # load from the given folder; labels stay as integers 0-9
# Inspect the MNIST dataset
print('Training images:', mnist.train.images)  # print the imported training data
print('Training images shape:', mnist.train.images.shape)  # shape of the training set
print('Test images shape:', mnist.test.images.shape)  # used to evaluate the final model
print('Validation images shape:', mnist.validation.images.shape)  # used to monitor accuracy during training
print('Training labels shape:', mnist.train.labels.shape)
# Display one MNIST image
import pylab
im = mnist.test.images[6]  # the 7th image of the test set (index 6)
im = im.reshape(-1, 28)
pylab.imshow(im)
pylab.show()

3. Recognizing blurry handwritten digits in the dataset

import tensorflow as tf  # import the TensorFlow library
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab 

tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # MNIST image dimension: 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # digits 0-9 => 10 classes

# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Build the model
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax classification

# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

# Parameter settings
learning_rate = 0.01
# Use the gradient descent optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 25
batch_size = 100
display_step = 1
saver = tf.train.Saver()
model_path = "log/mnist_model.ckpt"

# Start the session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())# Initializing OP

    # Loop over the epochs to train
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # Iterate over the whole training set
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # Print training progress
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print( " Finished!")

    # Test the model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Compute the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save model weights to disk
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)



## Load the model
#print("Starting 2nd session...")
#with tf.Session() as sess:
#    # Initialize variables
#    sess.run(tf.global_variables_initializer())
#    # Restore model weights from previously saved model
#    saver.restore(sess, model_path)
#    
#    # Test the model
#    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
#    # Compute the accuracy
#    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
#    
#    output = tf.argmax(pred, 1)
#    batch_xs, batch_ys = mnist.train.next_batch(2)
#    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
#    print(outputval,predv,batch_ys)
#
#    im = batch_xs[0]
#    im = im.reshape(-1,28)
#    pylab.imshow(im)
#    pylab.show()
#    
#    im = batch_xs[1]
#    im = im.reshape(-1,28)
#    pylab.imshow(im)
#    pylab.show()



If you find any errors or inaccuracies above, please point them out. Thank you!
