手寫數字識別這個例子是學tensorflow必經的小練習喔
話不多說來個正確率的圖,測試集正確率96%
啥?看着不好看?沒看到過程?莫急,看下圖,tensorboard可視化,多帥喔。
這是網絡的結構圖
訓練過程中的數據的變化
圖中可以看到asc(正確率)在隨着訓練不斷的上升,loss損失值也在不斷的下降,其中用到的
下面是code部分,大部分內容可以看代碼
創建一個簡單的神經網絡
# 這裏所說的神經網絡就是一些矩陣相乘,還要偏置的相加,最後保證輸出10個值,因爲手寫數字只有10個類別
# 可以自行改一下里面中間層的神經元個數之類的,不過要注意確保可以矩陣相乘的條件
# 裏面你可以加一層,兩層也可以的,注意輸出矩陣的格式呀求就好
# 就好比如:輸入的爲[None, 784], 因爲圖片有784個像素,也就是有784個特徵,
#所以第一層的權重必須是是 [784, 神經元的個數],偏置則是跟神經元個數一樣就行了。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../../tensorflow_code/mnist/", one_hot=True)
batch_size = 50
n_batch = mnist.train.num_examples // batch_size
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
with tf.variable_scope("first_net"):
W1 = tf.Variable(tf.random_normal([784, 50], stddev=0.1))
b1 = tf.Variable(tf.zeros([50]))
prediction1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = prediction1
with tf.variable_scope("second_net"):
W2 = tf.Variable(tf.random_normal([50, 50], stddev=0.1))
b2 = tf.Variable(tf.zeros([50]))
prediction2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = prediction2
with tf.variable_scope("third_net"):
W3 = tf.Variable(tf.random_normal([50, 50], stddev=0.1))
b3 = tf.Variable(tf.zeros([50]))
prediction3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = prediction3
with tf.variable_scope("last_net"):
W4 = tf.Variable(tf.zeros([50, 10]))
b4 = tf.Variable(tf.zeros([10]))
prediction = tf.nn.tanh(tf.matmul(L3_drop, W4) + b4)
with tf.variable_scope("loss"):
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
with tf.variable_scope("optimizer"):
train_step = tf.train.GradientDescentOptimizer(0.15).minimize(loss)
init = tf.global_variables_initializer()
with tf.variable_scope("accuary"):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar("loss", loss)
tf.summary.scalar("asc", accuracy)
tf.summary.histogram("w4", W4)
tf.summary.histogram("b4", b4)
init_op = tf.global_variables_initializer()
merge = tf.summary.merge_all()
with tf.Session() as sess:
sess.run(init)
filewriter = tf.summary.FileWriter("./model/", graph=sess.graph)
i += 1
for epoch in range(20):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
if i % 100 == 0:
summary = sess.run(merge, feed_dict={x: batch_xs, y: batch_ys})
filewriter.add_summary(summary, i)
test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))
小結: 這次分享的主要是比較基礎的手寫數字的識別,寫得可能不是很詳細,不過可以自行看下注釋,以及自己修改一些參數可以幫助理解,對於tensorboard不能顯示出圖的問題,會在我的另一篇博客中提到,歡迎多多指教。