MNIST數據集數字識別(一)

MNIST數據集數字識別

這裏是新入門的感知機識別數字的代碼詳解

感知機(Perceptron)實現MNIST數字識別

在 jupyter notebook 中進行實現。網絡結構是,具體代碼如下

1. 導入包

// 導入numpy; tensorflow; input_data包
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

2. 載入MNIST數據集,並創建默認的Interactive Session

// 載入MNIST數據集,one_hot=True表示一個長度爲n的數組只有一個元素是1.0,其它元素是0.0;而非one_hot標籤類似0 1 2 3 … n
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
// 創建默認的Interactive Session
sess = tf.InteractiveSession()

3. 感知機網絡結構

// 聲明輸入張量的格式,具體的數值在正式運行時給出 None 表示數據的行數不確定。也就是每個圖像是個784的列的向量,不確定有多少張圖像
x = tf.placeholder(tf.float32, [None, 784]) 
//784行,10列的二維數組中的所有值初始化爲0,W 表示權重(也就是重要程度,越重要,權重的值越大)
W = tf.Variable(tf.zeros([784, 10]))
// 初始化一個含有100個值的一維數組中,全部初始化爲0
b = tf.Variable(tf.zeros([10]))
// 輸出神經元,輸出矩陣 x 和 W 的乘積加上偏置 b
y = tf.nn.softmax(tf.matmul(x, W)+b)

4. 損失和優化器

交叉熵損失函數
在這裏插入圖片描述

// 真實的y標籤
y_ = tf.placeholder(tf.float32, [None, 10])
// 損失loss,此處用的爲交叉熵損失函數,多輸入單輸出。
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
// 優化器。用隨機梯度下降算法尋找最優點。但是SGD容易陷入局部最優,學習率爲0.5
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

5. 初始化和準確率

// 將所有變量初始化,並直接執行 run() 方法
tf.global_variables_initializer().run()
// tf.argmax(data,axis)是返回一維數組中張量最大的值所在的位置,axis=0,按列計算每列最大數的下標。axis=1,按行計算。
// tf.equal返回的是長度爲100(因爲每個批次有100條樣本)的一維數組,內容是布爾值true或者false
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
// tf.cast(data,dtype)的作用是將data的類型轉換爲dtype類型
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
// 比如這裏,把bool類型的correct_prediction轉換成tf.float32,就實現了true或者false變成了0或者1的轉換

6. 訓練

// 訓練階段,迭代10000for i in range(1000):
   // 每次隨機從訓練集中抽取100條樣本構成一個mini-batch
   batch_xs, batch_ys = mnist.train.next_batch(100)
   // train_step.run帶着有實際輸入的x和y_執行train_step這個operation
   train_step.run({x: batch_xs, y_: batch_ys})
   // 這裏使用的是eval,和run的區別是eval只能接收一個operation返回一個tensor結果
   // 這裏是將mnist.train.images和mnist.train.labels作爲x和y的值feed給了accuracy這個operation
   train_accuracy = accuracy.eval({x: mnist.train.images, y_: mnist.train.labels})
   print("Step%d, Training accuracy %g" % (i, train_accuracy))
print("準確率:",accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

7. 注意

若在訓練過程中有input_data的警告,加入以下代碼
import logging
// 下面的類用於解決read_data_sets拋出的警告
class WarningFilter(logging.Filter):
   def filter(self, record):
       msg = record.getMessage()
       tf_warning = 'retry (from tensorflow.contrib.learn.python.learn.datasets.base)' or 'from tensorflow.contrib.learn.python.learn.datasets.base' in msg
       return not tf_warning
           
logger = logging.getLogger('tensorflow')
logger.addFilter(WarningFilter())

PS: 初學圖像處理,還是小白,大家有問題多多交流。文中代碼借鑑官網及部分博客,有問題還請指正,十分感謝!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章