3 TensorFlow入門之識別手寫數字

————————————————————————————————————

寫在開頭:此文參照莫煩python教程(牆裂推薦!!!)

————————————————————————————————————

分類實驗之識別手寫數字

  • 這個實驗的內容是:基於TensorFlow,實現手寫數字的識別。
  • 這裏用到的數據集是大家熟知的mnist數據集。
  • mnist有五萬多張手寫數字的圖片,每個圖片用28x28的像素矩陣表示。所以我們的輸入層每個案列的特徵個數就有28x28=784個;因爲數字有0,1,2…9共十個,所以我們的輸出層是個1x10的向量。輸出層是十個小於1的非負數,表示該預測是0,1,2…9的概率,我們選取最大概率所對應的數字作爲我們的最終預測。
  • 真實的數字表示爲該數字所對應的位置爲1,其餘位置爲0的1x10的向量。

下面就開始實驗啦!

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#導入數據
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#如果還沒下載mnist就下載

#定義添加層
def add_layer(inputs,in_size,out_size,activation_function=None):
    #定義添加層內容,返回這層的outputs
    Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一個in_size行、out_size列的矩陣,開始時用隨機數填滿
    biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一個1行out_size列的矩陣,用0.1填滿
    Wx_plus_b = tf.matmul(inputs,Weights)+biases  #預測
    if activation_function is None:  #如果沒有激勵函數,那麼outputs就是預測值
        outputs = Wx_plus_b
    else:  #如果有激勵函數,那麼outputs就是激勵函數作用於預測值之後的值
        outputs = activation_function(Wx_plus_b)
    return outputs

#定義計算正確率的函數
def t_accuracy(t_xs,t_ys):
    global prediction
    y_pre = sess.run(prediction,feed_dict={xs:t_xs})
    correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
    accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
    result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys})
    return result

#定義神經網絡的輸入值和輸出值
xs = tf.placeholder(tf.float32,[None,784]) #None是不規定大小,這裏指的是案例個數,而輸入特徵個數爲28x28 = 784
ys = tf.placeholder(tf.float32,[None,10]) #Nnoe也是案例個數,不做規定;10是因爲有10個數字,所以輸出是10

#增加輸出層
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#這裏的激勵函數是softmax,此函數多用於多類分類

#計算誤差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #此誤差計算方式和softmax配套用,效果好

#訓練
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#學習因子爲0.5

#開始訓練
sess = tf.Session()
sess.run(tf.initialize_all_variables())

for i in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)   #提取數據集的100個數據,因爲原來數據太大了
    sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
    if i%50 == 0:
        print (t_accuracy(mnist.test.images,mnist.test.labels))  #每隔50個,打印一下正確率。注意:這裏是要用test的數據來測試
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.1849
0.6537
0.7393
0.7836
0.8053
0.8203
0.8275
0.837
0.8465
0.8504
0.8567
0.8571
0.8643
0.8637
0.8664
0.8687
0.8719
0.8742
0.8763
0.8773

上面4行就是下載的mnist數據集的四個文件。然後看打印出來的正確率可知,這個網絡的預測能力是越來越好的。
下面試一下啊,抽取500個數據來訓練,看看效果如何:

for i in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(500)   #提取數據集的500個數據,因爲原來數據太大了
    sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
    if i%50 == 0:
        print (t_accuracy(mnist.test.images,mnist.test.labels))  #每隔50個,打印一下正確率。注意:這裏是要用test的數據來測試
0.9001
0.9022
0.9023
0.9026
0.903
0.903
0.9037
0.9036
0.9034
0.9027
0.9041
0.903
0.9039
0.9034
0.9037
0.9046
0.9055
0.9045
0.9053
0.905

由上面打印出來的正確率可知,抽取500個數據來訓練的話,正確率會達到90%


*點擊[這兒:TensorFlow]發現更多關於TensorFlow的文章*


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章