本文通过搭建Softmax Regression(Softmax回归)模型,并用MNIST数据集对其进行训练和测试,介绍TensorFlow最基础的使用方式。
MNIST数据集介绍以及Softmax回归介绍参考:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
MNIST数据集导入
调用read_data_sets()导入数据集:第一个参数填MNIST数据集的存储路径,函数会自动检查该路径下是否已有下载好的数据,如果没有则自动下载。
# TensorFlow itself must be imported: the model-building and training code
# in this script refers to it as `tf`.
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

# Load the MNIST dataset from the "MNIST_data" directory, with labels
# requested as one-hot vectors (one_hot=True).
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
Softmax回归模型搭建
# Create the model.
# Placeholder for the input batch: any number of examples, 784 values each.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])

# Weights and biases, both initialised to zeros.
w = tf.Variable(tf.zeros(shape=[784, 10]))
b = tf.Variable(tf.zeros(shape=[10]))

# Linear layer producing unnormalised class scores (logits). Softmax is NOT
# applied here; the training loss (softmax_cross_entropy_with_logits) takes
# these logits directly.
y = tf.matmul(x, w) + b

# Placeholder for the ground-truth labels: 10 values per example.
y_ = tf.placeholder(tf.float32, [None, 10])
训练模型存储
# Start the session. The code below calls `Operation.run()` without passing a
# session, which relies on InteractiveSession registering itself as the
# default session.
sess = tf.InteractiveSession()
saver = tf.train.Saver()


def train(learning_rate=0.5, steps=1000, batch_size=100):
    """Train the softmax-regression model with mini-batch gradient descent.

    Uses the module-level graph tensors ``x``, ``y`` and ``y_``, the ``mnist``
    dataset object and the module-level session ``sess``.

    Args:
        learning_rate: Step size passed to ``GradientDescentOptimizer``.
        steps: Number of training iterations to run.
        batch_size: Number of training examples fetched per iteration.
    """
    # Mean softmax cross-entropy between the labels and the model's logits.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    # Op that performs one gradient-descent update on the loss.
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate).minimize(cross_entropy)

    # Initialise all variables before the first training step.
    tf.global_variables_initializer().run()

    # Train.
    for _ in range(steps):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})


try:
    train()
    # Save the trained model. The path is relative, so the checkpoint lands in
    # a "model" folder under the current working directory.
    # NOTE(review): older TensorFlow releases may require ./model to exist
    # beforehand -- confirm for the version in use.
    saver.save(sess, save_path='./model/mnistmodel.ckpt')
finally:
    # Release the session's resources even if training or saving fails.
    sess.close()
训练模型载入
载入模型时,必须先重建与训练时完全一致的网络结构(各变量的名称和形状都要与保存时对应),然后才能用Saver把检查点中保存的参数值恢复到这些变量上。
import tensorflow as tf
import numpy as np
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Build the model in a dedicated graph. A fresh graph starts variable naming
# from scratch, so the automatically generated variable names do not depend on
# whatever has already been added to the process-wide default graph.
myGraph = tf.Graph()
with myGraph.as_default():
    # Rebuild the network structure: the same variables, created in the same
    # order and with the same shapes as in the training script.
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    w = tf.Variable(tf.zeros(shape=[784, 10]))
    b = tf.Variable(tf.zeros(shape=[10]))
    y = tf.matmul(x, w) + b

    # Placeholder for the ground-truth labels.
    y_ = tf.placeholder(tf.float32, [None, 10])

    # Evaluation ops: a prediction counts as correct when the arg-max of the
    # logits equals the arg-max of the label vector; accuracy is the mean.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Saver used to restore the variables of this graph from the checkpoint.
    saver = tf.train.Saver()

# Test the trained model.
with tf.Session(graph=myGraph) as sess:
    # Restoring assigns the saved values, so no separate initialiser is run.
    saver.restore(sess, 'model/mnistmodel.ckpt')
    print('Weight:\n', sess.run(w))
    print('biases:\n', sess.run(b))
    # Accuracy on the MNIST test split.
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))