TensorFlow MNIST手寫數字識別(最佳實踐版)

(1) 引入函數庫

import numpy as np
import matplotlib.pyplot as plt
import os 
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.examples.tutorials.mnist import input_data

(2)加載數據

# Load the MNIST splits from datasets/MNIST_data/; one_hot=True requests
# one-hot encoded label vectors (n_y below is read from their second axis).
mnist = input_data.read_data_sets("datasets/MNIST_data/", one_hot=True)

(3)定義參數

# Fixed learning rate. Not referenced by the training graph, which derives its
# rate from LEARNING_RATE_BASE with exponential decay.
learning_rate = 0.0001
# Number of mini-batch training steps run by train() (one batch per loop
# iteration), despite the "epochs" name.
num_epochs = 10000
# Number of examples drawn per training step.
BATCH_SIZE = 100

# Initial learning rate for the exponentially decayed schedule.
LEARNING_RATE_BASE = 0.8
# Multiplicative decay factor applied to the learning rate.
LEARNING_RATE_DECAY = 0.99
# Decay rate for the exponential moving average of the trainable variables.
# (The name is misspelled but is referenced elsewhere in the file.)
MOVING_AVERAGE_DECCAY = 0.99
# Scale of the L2 regularization penalty on the weight matrices.
REGULARIZER_RATE = 0.0001
# Directory and file-name prefix used for checkpoints.
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"

# m: first axis of the training image array (used to size the decay schedule);
# n_x: input width per example (784 per the original author's note).
(m,n_x) = mnist.train.images.shape
# Width of a label vector (10 per the original author's note).
n_y = mnist.train.labels.shape[1]
# Number of units in the hidden layer.
n_1 = 500
# Cost samples appended by train() and plotted at the end of training.
costs = []

# Graph-level random seed for the default graph that exists at this point.
# NOTE(review): the __main__ block calls ops.reset_default_graph() before
# train(), which presumably discards this seed for the training graph — confirm.
tf.set_random_seed(1)

(4)初始化參數

def init_para():
    """Create the weights and biases of the two-layer network.

    Variables are created with ``tf.get_variable`` in the current default
    graph under the names "w1", "b1", "w2" and "b2". Weight matrices use the
    Xavier initializer with a fixed seed of 1; biases start at zero.

    Returns:
        A tuple ``(W1, b1, W2, b2)`` with shapes ``(n_x, n_1)``, ``(1, n_1)``,
        ``(n_1, n_y)`` and ``(1, n_y)`` respectively.
    """
    xavier = tf.contrib.layers.xavier_initializer
    zeros = tf.zeros_initializer

    # Hidden layer: n_x inputs -> n_1 units.
    hidden_weights = tf.get_variable("w1", [n_x, n_1], initializer=xavier(seed=1))
    hidden_bias = tf.get_variable("b1", [1, n_1], initializer=zeros())

    # Output layer: n_1 units -> n_y logits.
    output_weights = tf.get_variable("w2", [n_1, n_y], initializer=xavier(seed=1))
    output_bias = tf.get_variable("b2", [1, n_y], initializer=zeros())

    return hidden_weights, hidden_bias, output_weights, output_bias

(5)正向傳播

def forward(X, parameters, regularizer, variable_averages):
    """Compute the network's output logits for a batch of inputs.

    The network is one ReLU hidden layer followed by a linear output layer.
    No softmax is applied here; the caller applies it inside the loss.

    Args:
        X: Input tensor whose second axis has width ``n_x``.
        parameters: Tuple ``(W1, b1, W2, b2)`` as returned by ``init_para``.
        regularizer: Callable applied to each weight matrix; its results are
            added to the ``'losses'`` graph collection. Pass ``None`` to skip
            regularization.
        variable_averages: Object exposing ``average(var)`` (an
            ``ExponentialMovingAverage`` on which ``apply`` has already been
            called). When given, the averaged value of every parameter is used
            instead of the raw variable. Pass ``None`` to use raw variables.

    Returns:
        The un-normalized logits tensor of the output layer.
    """
    W1, b1, W2, b2 = parameters

    # Regularization penalties are always computed on the raw weight
    # variables, never on their moving averages.
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(W1))
        tf.add_to_collection('losses', regularizer(W2))

    # Swap every parameter for its moving average when requested, so a single
    # computation path below serves both cases.
    if variable_averages is not None:
        W1, b1, W2, b2 = [variable_averages.average(v) for v in (W1, b1, W2, b2)]

    hidden = tf.nn.relu(tf.matmul(X, W1) + b1)  # (batch, n_1)
    logits = tf.matmul(hidden, W2) + b2         # (batch, n_y)

    return logits

(6)模型訓練

def train():
    """Build the training graph, run mini-batch SGD and report accuracy.

    The graph is built in the current default graph. Training minimizes
    softmax cross-entropy plus the L2 penalties collected under ``'losses'``,
    using gradient descent with an exponentially decayed learning rate. Each
    training step also updates exponential moving averages of the trainable
    variables.

    Side effects:
        * Restores the latest checkpoint from ``MODEL_SAVE_PATH`` if one exists.
        * Every 500 steps: appends the current batch cost to the module-level
          ``costs`` list, prints it, and saves a checkpoint.
        * Prints train and test accuracy computed with the averaged parameters.
        * Shows a matplotlib plot of the recorded costs.
    """
    log_every = 500  # steps between cost samples / checkpoints

    X = tf.placeholder(tf.float32, shape=(None, n_x), name="X")
    Y = tf.placeholder(tf.float32, shape=(None, n_y), name="Y")

    parameters = init_para()
    global_step = tf.Variable(0, trainable=False)

    # L2 regularization of the weight matrices.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZER_RATE)

    # Moving averages of the trainable variables. apply() must run before
    # forward() is asked for averaged values, because average() looks up the
    # shadow variables that apply() creates.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECCAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    logits = forward(X, parameters, regularizer, None)
    logits_avg = forward(X, parameters, None, variable_averages)

    # Cross-entropy loss plus the collected regularization terms.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y))
    cost = cross_entropy + tf.add_n(tf.get_collection('losses'))

    # Staircase decay: the rate drops by LEARNING_RATE_DECAY once every
    # m / BATCH_SIZE steps. A distinct local name avoids shadowing the
    # module-level ``learning_rate`` float.
    decayed_learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, m/BATCH_SIZE,
        LEARNING_RATE_DECAY, staircase=True)

    optimizer = tf.train.GradientDescentOptimizer(decayed_learning_rate).minimize(
        cost, global_step=global_step)

    # One op that runs both the gradient step and the moving-average update.
    with tf.control_dependencies([optimizer, variable_averages_op]):
        train_op = tf.no_op(name='train')

    # Accuracy is measured with the moving-averaged parameters.
    correct_prediction = tf.equal(tf.argmax(logits_avg, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    saver = tf.train.Saver()

    # Make sure the checkpoint directory exists before the first save.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favour of this call.
        tf.global_variables_initializer().run()

        # Resume from the latest checkpoint when one is available.
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(num_epochs):
            x, y = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={X: x, Y: y})

            if i % log_every == 0:
                # Cost on the batch that was just trained on.
                cost_v = sess.run(cost, feed_dict={X: x, Y: y})
                costs.append(cost_v)
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
                print(i, cost_v)

        print("Train Accuracy:", accuracy.eval({X: mnist.train.images, Y: mnist.train.labels}))
        print("Test Accuracy:", accuracy.eval({X: mnist.test.images, Y: mnist.test.labels}))

    plt.plot(np.squeeze(costs))
    plt.ylabel('cost')
    # One point is recorded every ``log_every`` steps.
    plt.xlabel('iterations (per %d)' % log_every)
    # Report the numeric base rate; the decayed rate is a graph tensor whose
    # str() is not a number.
    plt.title("Initial learning rate = " + str(LEARNING_RATE_BASE))
    plt.show()

(7)模型評估

def evaluate(mnist):
    """Restore the latest checkpoint and print accuracy on the test split.

    A fresh graph is built so that nothing from the training graph is reused.
    The saver is constructed from ``variables_to_restore()`` of an
    ``ExponentialMovingAverage``, so the checkpoint's moving-average values
    are loaded into the plain model variables.

    Args:
        mnist: Dataset object exposing ``test.images`` and ``test.labels``.

    Returns:
        None. Prints the accuracy, or a notice when no checkpoint exists under
        ``MODEL_SAVE_PATH``.
    """
    with tf.Graph().as_default():
        X = tf.placeholder(tf.float32, shape=(None, n_x), name="X")
        Y = tf.placeholder(tf.float32, shape=(None, n_y), name="Y")
        test_feed = {X: mnist.test.images, Y: mnist.test.labels}

        parameters = init_para()
        # Plain forward pass: the averaged values arrive through the saver's
        # name mapping below, not through variable_averages.average().
        logits = forward(X, parameters, None, None)
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECCAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print("No checkpoint file found")
                return

            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoint files are named "<prefix>-<global_step>"; basename()
            # strips the directory regardless of the platform's separator.
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            accuracy_feed = sess.run(accuracy, feed_dict=test_feed)
            # The feed above is the test split, so the message says so.
            print("After %s training steps, test accuracy = %g" % (global_step, accuracy_feed))

(8)主程序

# Script entry point: train the model, then evaluate the saved checkpoint.
if __name__ =='__main__':
    # Start from an empty default graph so train() builds its variables fresh.
    ops.reset_default_graph()
    # Builds the graph in the default graph, trains, and writes checkpoints.
    train()
    # Builds its own separate graph and reads the latest checkpoint.
    evaluate(mnist)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章