tensorflow學習系列一:helloworld之minist

Tensorflow作爲當今主流的深度學習框架學習下還是很有必要的,相比其他的框架,tensorflow具有更大的靈活性;同時在資料方面也更具優勢,因此,tensorflow無疑將是我們進行深度學習開發的首選工具。好了,不多廢話了,總之,適合自己的纔是最好的,畢竟鞋子合不合腳自己穿了才知道。下面是基於mnist的一個hellworld程序,主要是採用softmaxRession對mnist進行分類,藉此瞭解tensorflow的基本流程:

(1)定義算法公式,也就是神經網絡的前向計算

(2)定義loss,選定優化器,並制定優化器的優化方式

(3)迭代對數據進行訓練

(4)在測試集或驗證集上對準確率進行評測

# -*- coding: utf-8 -*-
"""
Created on Sat Jun  9 12:20:04 2018

@author: kuang yong jian
"""
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data/',one_hot = True)

#定義輸入端口
x = tf.placeholder(tf.float32,[None,784])
y_ = tf.placeholder(tf.float32,[None,10])

#定義參數變量
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#算法實現
predict = tf.nn.softmax(tf.matmul(x,W) + b)

#定義損失函數
loss = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(predict),reduction_indices = [1]))

#定義優化算法
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

#全局參數初始化,這個一定要記得
init = tf.global_variables_initializer()

#定義會話窗口
sess = tf.Session()
sess.run(init)

#迭代模型訓練
for i in range(1000):
    batch_x,batch_y = mnist.train.next_batch(100)
    train_step.run(session = sess,feed_dict = {x:batch_x,y_:batch_y}) #也可以寫成:sess.run(train_step,feed_dict = {x:batch_x,y_:batch_y})
    #sess.run(train_step,feed_dict = {x:batch_x,y_:batch_y})
    

#模型性能評估    
correct_predict = tf.equal(tf.argmax(predict,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_predict,tf.float32))
print(accuracy.eval(session = sess,feed_dict = {x:mnist.test.images,y_:mnist.test.labels}))

若有不當之處,請指教,謝謝!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章