機器學習筆記 tensorflow mnist上實現CNN網絡

mnist上的一個普通cnn例子,採用兩層卷積和池化層加一層全連接,爲了防止過擬合在全連接層用了dropout,是一個十分簡單的例子

import tensorflow as tf
import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

x = tf.placeholder("float", [None, 784])
y_ = tf.placeholder("float", [None, 10])

x_image = tf.reshape(x, [-1, 28, 28, 1])
# layer_one
filter_one = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
filter_one_bias = tf.Variable(tf.zeros([32])+0.1)

filter_one_h = tf.nn.relu(tf.nn.conv2d(x_image, filter_one, strides=[1, 1, 1, 1], padding="SAME")+filter_one_bias)
filter_one_out = tf.nn.max_pool(filter_one_h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
# layer_two

filter_two = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
filter_two_bias = tf.Variable(tf.zeros([64])+0.1)

filter_two_h = tf.nn.relu(tf.nn.conv2d(filter_one_out, filter_two, strides=[1, 1, 1, 1], padding="SAME")+filter_two_bias)
filter_two_out = tf.nn.max_pool(filter_two_h, ksize=[1, 2, 2,1], strides=[1, 2, 2, 1], padding="SAME")

# full_connect

full_connect_w = tf.Variable(tf.truncated_normal([7*7*64, 1024], stddev=0.1))
full_connect_bias = tf.Variable(tf.zeros([1024])+0.1)

full_connect_out = tf.nn.relu(tf.matmul(tf.reshape(filter_two_out, [-1, 7*7*64]), full_connect_w)+full_connect_bias)

# drop_out
keep_drop = tf.placeholder("float")
drop_out = tf.nn.dropout(full_connect_out, keep_drop)

# softmax_loss
w = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10])+0.1)

loss = tf.nn.softmax(tf.matmul(drop_out, w)+b)

cross_entropy = -tf.reduce_sum(y_ * tf.log(loss)) #計算交叉熵
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #使用adam優化器來以0.0001的學習率來進行微調
correct_prediction = tf.equal(tf.argmax(loss,1), tf.argmax(y_,1)) #判斷預測標籤和實際標籤是否匹配
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))

sess = tf.Session() #啓動創建的模型

sess.run(tf.global_variables_initializer()) #初始化變量

for i in range(5000): #開始訓練模型,循環訓練5000次
    batch = mnist.train.next_batch(50) #batch大小設置爲50
    if i % 100 == 0:
        train_accuracy = accuracy.eval(session = sess,
                                       feed_dict = {x:batch[0], y_:batch[1], keep_drop:1.0})
        print("step %d, train_accuracy %g" %(i, train_accuracy))
    train_step.run(session = sess, feed_dict = {x:batch[0], y_:batch[1],
                   keep_drop:0.5}) #神經元輸出保持不變的概率 keep_prob 爲0.5

print("test accuracy %g" %accuracy.eval(session = sess,
      feed_dict = {x:mnist.test.images, y_:mnist.test.labels,
                   keep_drop:1.0})) #神經元輸出保持不變的概率 keep_prob 爲 1,即不變,一直保持輸出

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章