TensorFlow實例(5.1)--MNIST手寫數字進階算法(卷積神經網絡CNN)


MNIST 是Tensoflow提供的一個入門級的計算機視覺數據集,分爲兩部分(訓練集和測試集
其中訓練集共55000張,測試集共10000張,當爲None時隨機讀取 

在看這篇文章前,你必須先對神經網絡(NN)、MNIST手寫數字是什麼有初步的瞭解

關於神經網絡,你可能參考   機器學習(1)--神經網絡初探
關MNIST手寫數字 你可以參考   TensorFlow實例(4)--MNIST簡介及手寫數字分類算法

下面簡單介紹一下卷積神經網絡CNN,
圖片也是網上流行的圖片,我只做了一些簡單的修改,


同時我又分出了兩篇文章,闡述卷積(Convolution)、最大池化(MaxPooling)

TensorFlow實例(5.2)--MNIST手寫數字進階算法(卷積神經網絡CNN) 之 卷積tf.nn.conv2d

TensorFlow實例(5.3)--MNIST手寫數字進階算法(卷積神經網絡CNN) 之 最大池化tf.nn.max_pool


另外配有  三個文章py文件及MNIST集的下載,點擊下載

特別注意:這兩篇文章旨在說明這兩條指令數據演變,數據是另外建立,和這篇文章的數據沒有任何關係

本例代碼分以下幾部份
1、初始化數據及設置輸入圖像模型
2、構建兩次卷積、最大池化模型
3、兩次全連接模型
4、建立訓練、判斷正確、統計模型  及 訓練與測試


一、初始化數據及設置輸入圖像模型

# -*- coding:utf-8 -*-
import tensorflow as tf 
import tensorflow.examples.tutorials.mnist.input_data as input_data
import random

#讀取mnist數據,下載後的Mnist並解壓後,放在項目的同級目錄下,通過下面程序即可讀取
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#在訓練時,None表示從訓練集中取得一張圖表(x_data),及圖表的值(y_data)
#在測試評估模型時,None表示整個測試集合
x_data = tf.placeholder("float", [None, 784]) 
y_data = tf.placeholder("float", [None,10])

二、構建兩次卷積、最大池化模型
如果你對[5,5,1,32],,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME" 這些參數設置不瞭解

TensorFlow實例(5.2)--MNIST手寫數字進階算法(卷積神經網絡CNN) 之 卷積tf.nn.conv2d
TensorFlow實例(5.3)--MNIST手寫數字進階算法(卷積神經網絡CNN) 之 最大池化tf.nn.max_pool

以下對指令做些簡述
tf.reshape(x_data,[-1,28,28,1]) 將 圖像轉爲了28*28二維數組,第一維的-1爲batch相當於x_data中的None, 第四維的1表示1通道
tf.nn.conv2d ,卷積,卷積後將變爲  batch * 28 * 28 * 32 通道 的數組
tf.nn.relu  你可以簡單的理解爲,對於數組中-1的值設爲0,
tf.nn.max_pool 最大池化,ksize=[1,2,2,1],strides=[1,2,2,1] 簡單的說圖像變一半大

兩次的卷積,
第一次以圖像輸入 28*28*1通道  ,以 14*14*32通道  輸出
第二次以第一次的圖輸出 14*14*32 通道 爲輸入,  以 7*7*64通道 輸出

#第一層卷積與最大池化
w1 = tf.Variable(tf.truncated_normal([5,5,1,32],stddev=0.1)) #建立權重weight
b1 = tf.Variable(tf.constant(0.1,shape=[32])) 
h1 = tf.nn.relu(tf.nn.conv2d(tf.reshape(x_data,[-1,28,28,1]) ,w1,strides=[1,1,1,1],padding="SAME") + b1)
p1 = tf.nn.max_pool(h1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")

#第二層卷積與最大池化
w2 = tf.Variable(tf.truncated_normal([5,5,32,64],stddev=0.1))
b2 = tf.Variable(tf.constant(0.1,shape=[64]))
h2 = tf.nn.relu(tf.nn.conv2d(p1 ,w2,strides=[1,1,1,1],padding="SAME") + b2)
p2 = tf.nn.max_pool(h2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")

三、兩次全連接模型
兩次全連接
第一次,以卷積輸出 7*7*64通道  全連接爲 連接成爲一個一維向量1024個數據,
第二次,以第一次的一維向量1024個元素輸入,全連接爲10個元素,
10個元素就對應,輸出的每一維都是圖片y屬於該類別的概率。


如果不明白這個10個元素

可以參考 TensorFlow實例(4)--MNIST簡介及手寫數字分類算法


Dropout作用其在訓練階段阻止神經元的共適應,本文不對這個做細論,有興趣可以百度一下。


#緊密連接層一
#將第二層max-pooling的輸出連接成爲一個一維向量,作爲該層的輸入。
wf1 = tf.Variable(tf.truncated_normal([7 * 7 * 64,1024],stddev=0.1))
bf1 = tf.Variable(tf.constant(0.1,shape=[1024]))
fc1 = tf.nn.relu(tf.matmul(tf.reshape(p2,[-1,7 * 7 * 64]),wf1) + bf1)

keep_prob = tf.placeholder(tf.float32)
fc1_drop = tf.nn.dropout(fc1,keep_prob=keep_prob)

#緊密連接層二
#Softmax層:輸出爲10,輸出的每一維都是圖片y屬於該類別的概率。
wf2 = tf.Variable(tf.truncated_normal([1024,10],stddev=0.1))
bf2 = tf.Variable(tf.constant(0.1,shape=[10]))
y = tf.nn.softmax(tf.matmul(fc1_drop,wf2) + bf2)

四、建立訓練、判斷正確、統計模型 及 訓練與測試

#建立訓練模型
loss = -tf.reduce_sum(y_data * tf.log(y))
train = tf.train.AdamOptimizer(1e-4).minimize(loss)

#建立判斷正確與統計模型
correct = tf.equal(tf.argmax(y_data,1),tf.argmax(y,1))#比較訓練集中的結果與計算的結果,返回TRUE
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))#因爲correct返回爲TRUE,轉化爲float32的1.0,對傳入的整個batch求平均值,即爲正確率

#開始訓練
sess = tf.Session()
sess.run(tf.initialize_all_variables())
for i in range(250): #在google提供的例子中循環運算是20000次,太慢了,所以我只設置到250
    batch = mnist.train.next_batch(50)#每次的運算提取訓練集中的50張圖片
    if i % 50 == 0:
        #每50次,調用一次判斷正確與統計模型
        trainTmp = sess.run(accuracy,feed_dict={x_data:batch[0],y_data:batch[1],keep_prob:1.0})
        print("第%d步,正確率:%g" % (i,trainTmp))
    #簡單說一下keep_prob設置不同,在做判斷與統計時,我們不做dropout,所以設1.0,但在訓練時我們需要進行dropout,所以設爲0.5
    sess.run(train,feed_dict={x_data:batch[0],y_data:batch[1],keep_prob:0.5})

#對測試集進行測試
#在google提供的例子是沒有[0:50]的,即計算所有測試集的值,還是因爲太慢,我就取了前50個
print("前50個測試集的正確率:" + str(sess.run(accuracy,feed_dict={x_data:mnist.test.images[0:50],y_data:mnist.test.labels[0:50],keep_prob:1})))


發佈了46 篇原創文章 · 獲贊 14 · 訪問量 2萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章