Tensorflow學習筆記:基礎篇(4)——Mnist手寫集改進版(添加隱藏層)

Tensorflow學習筆記:基礎篇(4)——Mnist手寫集改進版(添加隱藏層)


前序

— 前文中,我們的初始版本實現了一個非常簡單的兩層全連接網絡來完成MNIST數據的分類問題,輸入層784個神經元,輸出層10個神經元,最終迭代計算20次,準確率在0.91左右,本文我們採取添加隱藏層的方法進行訓練,看看效果如何
Reference:前文博客:Mnist手寫集初始版本


計算流程

1、數據準備

2、準備好placeholder

3、初始化參數/權重

4、計算預測結果

5、計算損失值

6、初始化optimizer

7、指定迭代次數,並在session執行graph


代碼示例

1、數據準備

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 載入數據集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# 每個批次送100張圖片
batch_size = 100
# 計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

2、準備好placeholder

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

3、初始化參數/權重

此處我們添加了兩個隱藏層,分別有500和300個神經元,這樣包括輸入輸出層,總共4層神經網絡
其中:
(1)隱藏層初始化函數建議使用tf.truncated_normal()(截短的隨機數)類型,而非前文中的tf.zero()(初始化爲零)類型
(2)中間層的激活函數,本文使用tanh(雙曲正切函數),建議讀者可以嘗試運用ReLU函數或者Sigmoid函數,比較一下輸出結果

W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1), name='W1')
b1 = tf.Variable(tf.zeros([500]) + 0.1, name='b1')
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1, name='L1')

W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1), name='W2')
b2 = tf.Variable(tf.zeros([300]) + 0.1, name='b2')
L2 = tf.nn.tanh(tf.matmul(L1, W2) + b2, name='L2')

W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1), name='W3')
b3 = tf.Variable(tf.zeros([10]) + 0.1, name='b3')

4、計算預測結果

最後一層的激活函數依然是softmax函數

prediction = tf.nn.softmax(tf.matmul(L2, W3) + b3)

5、計算損失值

這裏我們依舊使用二次代價函數

loss = tf.reduce_mean(tf.square(y - prediction))

6、初始化optimizer

learning_rate = 0.2
optimizer =  tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# 結果存放在一個布爾型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  

# 求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

7、指定迭代次數,並在session執行graph

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})

        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})

        if epoch % 2 == 0:
            print("Iter" + str(epoch) + ", Testing accuracy:" + str(test_acc))

運行結果

迭代計算20次,準確率0.95左右,準確率較上次的0.91有了一定的提升,請讀者思考並嘗試,如何繼續修改以繼續提高準確率呢~~
這裏寫圖片描述


完整代碼

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# 每個批次的大小
batch_size = 100
# 計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784], name='x_input')
y = tf.placeholder(tf.float32, [None, 10], name='y_input')


W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1), name='W1')
b1 = tf.Variable(tf.zeros([500]) + 0.1, name='b1')
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1, name='L1')

W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1), name='W2')
b2 = tf.Variable(tf.zeros([300]) + 0.1, name='b2')
L2 = tf.nn.tanh(tf.matmul(L1, W2) + b2, name='L2')

W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1), name='W3')
b3 = tf.Variable(tf.zeros([10]) + 0.1, name='b3')

prediction = tf.nn.softmax(tf.matmul(L2, W3) + b3)

# 二次代價函數
loss = tf.reduce_mean(tf.square(y - prediction))


# 梯度下降
optimizer = tf.train.GradientDescentOptimizer(0.2).minimize(loss)


correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(21):

        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})

        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})

        if epoch % 2 == 0:
            print("Iter" + str(epoch) + ", Testing accuracy:" + str(test_acc))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章