Tensorflow框架下識別手寫字神經網絡代碼

不借助任何架構的神經網絡代碼在代碼可讀性上能夠很好的表達出神經網絡代碼是如何工作的,但是代碼運行效率卻很低.或者說對硬件的要求很高,因爲python語言的運行效率很低.
Google的tensorflow架構很好的在硬件設備上搭建神經網絡的代碼,該架構在各個開源社區有無數教程.可以去社區瞭解tensorflow的架構與基礎.
(一) Tensorflow加載數據集
Tensorflow數據集的加載進行和模塊化的打包:

from tensorflow.examples.tutorials.mnist import input_data
if os.path.exists('/ysk/code/tensorflowcode/MNIST_data/'):
    mnist = input_data.read_data_sets("/ysk/code/tensorflowcode/MNIST_data/", one_hot=True)
else:
    print("file not exist!")

以上返回一個mnist,包括60000行的訓練數據集和10000行的測試數據集.訓練集中,訓練的圖片叫做mnist.train.images,訓練的標籤叫做mnist.train.labels
(二) 構建計算圖

#x表示輸入,行數代表訓練圖片張數,784爲一個圖片的像素,因此每一行爲一個訓練的圖片像素
x = tf.placeholder("float", [None, 784])
#W代表第一層之間的權重,這裏的權重的維度與不使用架構的維度是行列相反的
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

#這裏的y是對輸出層進行了softmax的分類.
y = tf.nn.softmax(tf.matmul(x, W) + b)
#這裏的y_表示訓練集中的標籤,是真是的數字
y_ = tf.placeholder("float", [None, 10])

#代價函數使用的是交叉熵代價函數,簡單介紹交叉熵代價函數表示的意義爲利用預測輸出(y)來表示真正的數字(y_)的困難程度,也就是說使用我們經過神經網絡訓練的輸出來與真實的標籤數字進行對比.交叉熵的值越大,說明用訓練出來的數據表示真實數據的困難程度就越大,因此神經網絡的主要目標就是降低交熵的數值.
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#訓練採用的算法可以自己選擇,參考博客`https://www.cnblogs.com/ranjiewen/p/5938944.html`,該博客說明了各種梯度下降函數的含義以及梯度下降的快慢.
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#train_step = tf.train.AdadeltaOptimizer(0.01).minimize(cross_entropy)
#train_step = tf.train.RMSPropOptimizer(0.01).minimize(cross_entropy)
#train_step = tf.train.AdagradOptimizer(0.01).minimize(cross_entropy)

#tensorflow中一定要將所有的變量進行初始化,只有初始化後的節點才能被應用.但是tensorflow中初始化函數使用的語句爲下面:
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys})
    #equal語句是將輸出的[100, 10]的矩陣與標籤的[100, 10]的矩陣進行比較,返回的矩陣是如:[true, false, false]形式的矩陣.
    cross_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    #因爲生成的矩陣爲上面個的形式,所以需要將矩陣形式轉換爲真正的數值形式
    accuracy = tf.reduce_mean(tf.cast(cross_prediction, "float"))
    #在訓練完以後要對神經網絡的準確性進行測試.
    test_acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels})
    if i % 10 == 0:
        print("After %d steps trainging, the accuracy is %g " %(i, test_acc))
發佈了50 篇原創文章 · 獲贊 36 · 訪問量 12萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章