This article builds a Softmax Regression model and trains and tests it on the MNIST dataset, as an introduction to the most basic usage of TensorFlow.
For an introduction to the MNIST dataset and to Softmax regression, see: http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
Importing the MNIST dataset
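Call read_data_sets() with the directory where the MNIST dataset is stored as the first argument. The function checks whether the data has already been downloaded to that directory and downloads it only if needed.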
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
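With one_hot=True, each label is returned as a 10-element vector containing a single 1 at the index of the digit, rather than as the digit itself. The following is a minimal sketch of that encoding in plain Python; the helper name one_hot is hypothetical and is not part of TensorFlow or of the loader above.

```python
def one_hot(label, num_classes=10):
    """Return a list of length num_classes with 1.0 at index `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


# The digit 3 becomes a vector whose fourth entry is 1.0.
encoded = one_hot(3)
print(encoded)
```

This is the layout the model expects for its label placeholder, whose shape is [None, 10].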
Building the Softmax regression model
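import tensorflow as tf

# Create the model
# Placeholder for the input images: each row is a flattened 28x28 image (784 values)
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
# Weights and biases, initialized to zero
w = tf.Variable(tf.zeros(shape=[784, 10]))
b = tf.Variable(tf.zeros(shape=[10]))
# Softmax model: y holds the raw scores (logits); the softmax itself is applied inside the loss
y = tf.matmul(x, w) + b

# Define loss and optimizer
# Placeholder for the one-hot ground-truth labels
y_ = tf.placeholder(tf.float32, [None, 10])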
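To make the arithmetic behind the model concrete, here is a small plain-Python sketch of the linear step x·w + b followed by softmax. It does not use TensorFlow; the helpers linear and softmax and the toy sizes (3 features, 2 classes) are assumptions chosen for illustration only.

```python
import math


def linear(x, w, b):
    """Compute x @ w + b for one input row x (length n), w (n x k), b (length k)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j]
            for j in range(len(b))]


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy example: 3 input features, 2 classes.
x = [1.0, 2.0, 3.0]
w = [[0.1, -0.1], [0.2, -0.2], [0.3, -0.3]]
b = [0.0, 0.0]

logits = linear(x, w, b)
probs = softmax(logits)
print(logits, probs)

# With all-zero weights and biases (the initial state of the model above),
# every class gets the same probability.
uniform = softmax([0.0] * 10)
print(uniform)
```

The zero-initialized case shows why the untrained model starts out predicting each digit with equal probability.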
Saving the trained model
# Start the session
sess = tf.InteractiveSession()
saver = tf.train.Saver()

def train():
    # Cross-entropy loss, averaged over the batch
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    # Gradient descent with learning rate 0.5; minimize() adds the backpropagation ops
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    tf.global_variables_initializer().run()
    # Train for 1000 steps on mini-batches of 100 images
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

train()
# Save the model; the relative path places it in the model folder under the project directory
saver.save(sess, save_path='./model/mnistmodel.ckpt')
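The training step combines a softmax cross-entropy loss with gradient descent. As a rough illustration of what one such step computes, the plain-Python sketch below updates the weights on a single toy example. It is a simplification under stated assumptions: one example instead of a mini-batch of 100, 2 features and 3 classes instead of 784 and 10, and hand-written helpers (softmax, cross_entropy, sgd_step) that are hypothetical and not TensorFlow functions.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, target):
    """Cross-entropy between predicted probabilities and a one-hot target."""
    return -sum(t * math.log(p) for p, t in zip(probs, target) if t > 0)


def sgd_step(x, target, w, b, lr=0.5):
    """Apply one gradient descent step in place; return the loss before the update."""
    k = len(b)
    logits = [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(k)]
    probs = softmax(logits)
    loss = cross_entropy(probs, target)
    # For softmax cross-entropy, the gradient w.r.t. the logits is probs - target.
    grad = [probs[j] - target[j] for j in range(k)]
    for i in range(len(x)):
        for j in range(k):
            w[i][j] -= lr * x[i] * grad[j]
    for j in range(k):
        b[j] -= lr * grad[j]
    return loss


# Toy data: one example with 2 features whose true class is index 1.
x = [1.0, 0.5]
target = [0.0, 1.0, 0.0]
w = [[0.0] * 3 for _ in range(2)]
b = [0.0] * 3

losses = [sgd_step(x, target, w, b) for _ in range(5)]
print(losses)
```

The first loss equals ln 3 because zero weights give a uniform prediction over the three classes, and each later step lowers the loss on this example.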
Loading the trained model
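Before loading the model, the network structure must be rebuilt in full, so that every saved parameter has a matching variable to be restored into.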
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Rebuild the network structure
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
w = tf.Variable(tf.zeros(shape=[784, 10]))
b = tf.Variable(tf.zeros(shape=[10]))
y = tf.matmul(x, w) + b
# Placeholder for the one-hot ground-truth labels
y_ = tf.placeholder(tf.float32, [None, 10])

# Restore the saved variables
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'model/mnistmodel.ckpt')
    print('Weight:\n', sess.run(w))
    print('biases:\n', sess.run(b))
    # Test the trained model
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
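The accuracy computation compares the index of the largest score in each prediction with the index of the 1 in the one-hot label, then averages the matches. A plain-Python sketch of the same idea follows; the helpers argmax and accuracy and the four toy rows are made up for illustration and are not results from the MNIST model.

```python
def argmax(values):
    """Index of the largest value (the first one if there is a tie)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits_batch, labels_batch):
    """Fraction of rows whose predicted class matches the one-hot label."""
    correct = [argmax(scores) == argmax(label)
               for scores, label in zip(logits_batch, labels_batch)]
    return sum(correct) / len(correct)


# Toy batch: 4 predictions over 3 classes, with one-hot labels.
logits_batch = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 0.2, 0.9],   # predicts class 2
    [0.0, 1.5, 0.2],   # predicts class 1
    [1.0, 0.9, 0.8],   # predicts class 0
]
labels_batch = [
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 0],  # true class is 0, so the third prediction is wrong
    [1, 0, 0],
]

acc = accuracy(logits_batch, labels_batch)
print(acc)
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw logits gives the same predicted class as taking it over the probabilities.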