Tensorflow 訓練實現“與”功能的神經網絡實列

針對Tensorflow初學者,下面是一個非常簡單的入門程序。
程序的目標是訓練一個具有2個輸入神經元、3個隱含層神經元和一個輸入出神經元的神經網絡,實現二進制邏輯與功能。
因此訓練數據只有 (0,0) (0,1) (1,0) (1,1)四組
輸出爲 0,0,0,1

這個程序幾乎設計Tensorflow的很多細節。從張量圖構造,訓練,模型保存,模型驗證等等。

#!/usr/bin/python3
#用TensorFlow訓練“與”功能的神經網絡
#道老師 2019-12-08
#該程序在 CentOS7.0  
#GPU Quadro P5000  NVIDIA-SMI 430.50       Driver Version: 430.50       CUDA Version: 10.1
#Tensorflow.version=1.10.0 下運行正常

import tensorflow as tf
import numpy as np

import time
import logging
import sys
import math

#日誌打印,%(asctime)s 時間,%(message)s代表後面的要輸出的內容
logging.basicConfig(format='%(asctime)s: %(message)s',
                    level=logging.DEBUG,
                    stream=sys.stdout)

MODEL_PATH = "./models/"  #模型參數保存地址
SUMMARY_PATH = "./logs/" #訓練過程中的信息輸出保存地址,主要用於tensorboard可視化
NUM_EPOCHS = 1000  #最大迭代次數
INPUT_SIZE = 2   #輸出層神經元個數,由於是0,1的與,所以輸入數據維度是2
HIDDEN_SIZE = 3  #隱藏層神經元個數(可以是其它任意整數)
OUTPUT_SIZE = 1  #輸出層神經元個數,與操作結果是0/1 所以這裏固定爲1
train_x=np.array([[0,0],[0,1],[1,0],[1,1]]) #訓練數據,維度[n,2],由於是二進制所以n最大隻能爲4個
train_y=np.array([[0]  ,[0],  [0],  [1]]) #訓練數據標籤,與運算結果。

#構造張量圖
config = tf.ConfigProto()
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, [None, INPUT_SIZE]) #訓練數據佔位符
    y = tf.placeholder(tf.float32, [None,OUTPUT_SIZE])#訓練數據標籤佔位符
   
   
    #隱含層參數
    W21 = tf.Variable(tf.truncated_normal([INPUT_SIZE,HIDDEN_SIZE],-0.1,0.1),name = "weight21")
    b2 = tf.Variable(tf.truncated_normal([HIDDEN_SIZE],-0.1,0.1),name="biases2")
    #輸出層參數
    W32 = tf.Variable(tf.truncated_normal([HIDDEN_SIZE,OUTPUT_SIZE],-0.1,0.1),name = "weight32")
    b3 = tf.Variable(tf.truncated_normal([OUTPUT_SIZE],-0.1,0.1),name="biases3")
    #前向計算,
    a2 = tf.sigmoid(tf.matmul(x,W21)+b2) #隱含層輸出
    y_out  = tf.sigmoid(tf.matmul(a2,W32)+b3) #輸出層輸出,0~1的值
   
    
    #前面的代碼塊兒也可以用下面的tensorflow裏的全連接層代替,這樣就不用自己實現變量定義,前向計算過程了。dense層已經封裝好了上述計算過程
    #dense1 = tf.layers.dense(inputs=x, units=HIDDEN_SIZE, activation=tf.nn.sigmoid)
    #y_out  = tf.layers.dense(inputs=dense1, units=OUTPUT_SIZE, activation=tf.nn.sigmoid)
    
    y_pred = tf.cast(y_out>0.5, dtype=tf.float32)#預測輸出 0或1
    
    accuracy=tf.reduce_mean(tf.cast(tf.equal(y,y_pred), dtype=tf.float32)) * 100#識別準確率,就是訓練標籤y和預測y_pred逐一比較
    
    #交叉熵損失函數
    loss = -tf.reduce_sum(y*tf.log(y_out)+(1-y)*tf.log(1-y_out))
    #loss=tf.reduce_mean(tf.square(y_out-y))  #也可以用均方誤差損失函數
    
    
    tf.summary.scalar("loss", loss)  #把loss增加到summary裏,用於tensorboard可視化
    tf.summary.scalar("accuracy", accuracy)  #把accuracy增加到summary裏,用於tensorboard可視化
    
    #以最小化loss作爲優化目標,學習率0.1,可以設置其它值
    optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

#-----------------------圖構造完畢-------------------------


#這部分是讀取保存的模型,所以不需要再訓練,直接就能做預測了。
#如果像用以訓練的模型做預測,可以把下面5行註釋方開

#with tf.Session(config=config, graph=graph) as sess:
#    saver = tf.train.Saver()
#    saver.restore(sess,MODEL_PATH+"1.ckpt")
#    logging.info(sess.run(y_pred,feed_dict={x:train_x}))
#exit()


#---------------開始訓練----------------
with tf.Session(config=config, graph=graph) as sess:
    saver = tf.train.Saver() #用於保存模型參數
    merged_summary = tf.summary.merge_all() #自動管理模式
    summary_writer = tf.summary.FileWriter(SUMMARY_PATH, tf.get_default_graph())#日誌保存文件名稱
    tf.global_variables_initializer().run()#初始化變量
    for current_epoch in range(NUM_EPOCHS):#循環迭代
        start_time = time.time()#開始時間
        #計算一次,其中feed_dict指示佔位符x,y的實際對應數據是誰
        #計算[loss,optimizer,accuracy,merged_summary]這些內容,並一次返回給train_loss,_,acc,summary
        #“_”,返回不獲取
        train_loss,_,acc,summary=sess.run([loss,optimizer,accuracy,merged_summary],feed_dict={x:train_x,y:train_y})
        summary_writer.add_summary(summary, current_epoch)#寫一次summary,用於可視化
        #打印訓練信息
        logging.info("Epoch(time: %.3f s) %d/%d loss=%f  ACC=%.3f)",
                time.time() - start_time,
                current_epoch + 1,
                NUM_EPOCHS,
                train_loss,
                acc)
    
    #最後計算一下預測輸出看看
    logging.info(sess.run(y_pred,feed_dict={x:train_x}))
    
    #模型保存
    save_path = saver.save(sess, MODEL_PATH+"1.ckpt")
    logging.info("Model saved in file: %s", save_path)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章