使用tensorflow進行手寫數字識別

首先要在對應的目錄下安裝好手寫數字識別數據集。

編寫代碼如下所示:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("F:/anaconda/workspace/Data/MNIST_data",one_hot=True)
#設置每個批次的大小,一次運算100張圖片
batch_size = 100
#計算共有多少批次
n_batch = mnist.train.num_examples  // batch_size
#創建兩個placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#創建簡單的神經網絡
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.sigmoid(tf.matmul(x,W)+b)
#二次代價函數
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
# train_step = tf.train.AdamOptimizer(0.01).minimize(loss)
#初始化變量
init = tf.global_variables_initializer()

#結果存放在一個布爾類型列表中 argmax:返回一位張量中的最大值所在的位置(概率最大的位置)
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#計算準確率 cast:把true轉化爲1.0,false轉化爲0.0
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for bach in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        #計算準確率
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter "+ str(epoch) + "Testing Accuracy "+ str(acc))

代價 函數可以更換,本文使用了兩種代價函數,一個是二次代價函數另一個是交叉熵代價函數,進行20次訓練後的準確率爲:

#交叉熵
Iter 0Testing Accuracy 0.8666
Iter 1Testing Accuracy 0.8774
Iter 2Testing Accuracy 0.8841
Iter 3Testing Accuracy 0.8874
Iter 4Testing Accuracy 0.8895
Iter 5Testing Accuracy 0.893
Iter 6Testing Accuracy 0.8944
Iter 7Testing Accuracy 0.8971
Iter 8Testing Accuracy 0.8972
Iter 9Testing Accuracy 0.8968
Iter 10Testing Accuracy 0.8996
Iter 11Testing Accuracy 0.8998
Iter 12Testing Accuracy 0.9011
Iter 13Testing Accuracy 0.9014
Iter 14Testing Accuracy 0.9009
Iter 15Testing Accuracy 0.9014
Iter 16Testing Accuracy 0.9016
Iter 17Testing Accuracy 0.9021
Iter 18Testing Accuracy 0.9032
Iter 19Testing Accuracy 0.9034
Iter 20Testing Accuracy 0.903

#二次代價函數
Iter 0Testing Accuracy 0.8175
Iter 1Testing Accuracy 0.8515
Iter 2Testing Accuracy 0.8639
Iter 3Testing Accuracy 0.8709
Iter 4Testing Accuracy 0.8769
Iter 5Testing Accuracy 0.8809
Iter 6Testing Accuracy 0.8844
Iter 7Testing Accuracy 0.8865
Iter 8Testing Accuracy 0.8896
Iter 9Testing Accuracy 0.8907
Iter 10Testing Accuracy 0.8921
Iter 11Testing Accuracy 0.8933
Iter 12Testing Accuracy 0.8947
Iter 13Testing Accuracy 0.8962
Iter 14Testing Accuracy 0.8965
Iter 15Testing Accuracy 0.897
Iter 16Testing Accuracy 0.8985
Iter 17Testing Accuracy 0.8989
Iter 18Testing Accuracy 0.8994
Iter 19Testing Accuracy 0.8999
Iter 20Testing Accuracy 0.9005

看起來兩者的差距並不是很大。在這裏的代價函數和優化器自己可以調整。

 

更多內容請掃描下方二維碼關注博主微信公衆號:程序員大管

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章