Tensorflow 训练实现“与”功能的神经网络实列

针对Tensorflow初学者,下面是一个非常简单的入门程序。
程序的目标是训练一个具有2个输入神经元、3个隐含层神经元和一个输入出神经元的神经网络,实现二进制逻辑与功能。
因此训练数据只有 (0,0) (0,1) (1,0) (1,1)四组
输出为 0,0,0,1

这个程序几乎设计Tensorflow的很多细节。从张量图构造,训练,模型保存,模型验证等等。

#!/usr/bin/python3
#用TensorFlow训练“与”功能的神经网络
#道老师 2019-12-08
#该程序在 CentOS7.0  
#GPU Quadro P5000  NVIDIA-SMI 430.50       Driver Version: 430.50       CUDA Version: 10.1
#Tensorflow.version=1.10.0 下运行正常

import tensorflow as tf
import numpy as np

import time
import logging
import sys
import math

#日志打印,%(asctime)s 时间,%(message)s代表后面的要输出的内容
logging.basicConfig(format='%(asctime)s: %(message)s',
                    level=logging.DEBUG,
                    stream=sys.stdout)

MODEL_PATH = "./models/"  #模型参数保存地址
SUMMARY_PATH = "./logs/" #训练过程中的信息输出保存地址,主要用于tensorboard可视化
NUM_EPOCHS = 1000  #最大迭代次数
INPUT_SIZE = 2   #输出层神经元个数,由于是0,1的与,所以输入数据维度是2
HIDDEN_SIZE = 3  #隐藏层神经元个数(可以是其它任意整数)
OUTPUT_SIZE = 1  #输出层神经元个数,与操作结果是0/1 所以这里固定为1
train_x=np.array([[0,0],[0,1],[1,0],[1,1]]) #训练数据,维度[n,2],由于是二进制所以n最大只能为4个
train_y=np.array([[0]  ,[0],  [0],  [1]]) #训练数据标签,与运算结果。

#构造张量图
config = tf.ConfigProto()
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, [None, INPUT_SIZE]) #训练数据占位符
    y = tf.placeholder(tf.float32, [None,OUTPUT_SIZE])#训练数据标签占位符
   
   
    #隐含层参数
    W21 = tf.Variable(tf.truncated_normal([INPUT_SIZE,HIDDEN_SIZE],-0.1,0.1),name = "weight21")
    b2 = tf.Variable(tf.truncated_normal([HIDDEN_SIZE],-0.1,0.1),name="biases2")
    #输出层参数
    W32 = tf.Variable(tf.truncated_normal([HIDDEN_SIZE,OUTPUT_SIZE],-0.1,0.1),name = "weight32")
    b3 = tf.Variable(tf.truncated_normal([OUTPUT_SIZE],-0.1,0.1),name="biases3")
    #前向计算,
    a2 = tf.sigmoid(tf.matmul(x,W21)+b2) #隐含层输出
    y_out  = tf.sigmoid(tf.matmul(a2,W32)+b3) #输出层输出,0~1的值
   
    
    #前面的代码块儿也可以用下面的tensorflow里的全连接层代替,这样就不用自己实现变量定义,前向计算过程了。dense层已经封装好了上述计算过程
    #dense1 = tf.layers.dense(inputs=x, units=HIDDEN_SIZE, activation=tf.nn.sigmoid)
    #y_out  = tf.layers.dense(inputs=dense1, units=OUTPUT_SIZE, activation=tf.nn.sigmoid)
    
    y_pred = tf.cast(y_out>0.5, dtype=tf.float32)#预测输出 0或1
    
    accuracy=tf.reduce_mean(tf.cast(tf.equal(y,y_pred), dtype=tf.float32)) * 100#识别准确率,就是训练标签y和预测y_pred逐一比较
    
    #交叉熵损失函数
    loss = -tf.reduce_sum(y*tf.log(y_out)+(1-y)*tf.log(1-y_out))
    #loss=tf.reduce_mean(tf.square(y_out-y))  #也可以用均方误差损失函数
    
    
    tf.summary.scalar("loss", loss)  #把loss增加到summary里,用于tensorboard可视化
    tf.summary.scalar("accuracy", accuracy)  #把accuracy增加到summary里,用于tensorboard可视化
    
    #以最小化loss作为优化目标,学习率0.1,可以设置其它值
    optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

#-----------------------图构造完毕-------------------------


#这部分是读取保存的模型,所以不需要再训练,直接就能做预测了。
#如果像用以训练的模型做预测,可以把下面5行注释方开

#with tf.Session(config=config, graph=graph) as sess:
#    saver = tf.train.Saver()
#    saver.restore(sess,MODEL_PATH+"1.ckpt")
#    logging.info(sess.run(y_pred,feed_dict={x:train_x}))
#exit()


#---------------开始训练----------------
with tf.Session(config=config, graph=graph) as sess:
    saver = tf.train.Saver() #用于保存模型参数
    merged_summary = tf.summary.merge_all() #自动管理模式
    summary_writer = tf.summary.FileWriter(SUMMARY_PATH, tf.get_default_graph())#日志保存文件名称
    tf.global_variables_initializer().run()#初始化变量
    for current_epoch in range(NUM_EPOCHS):#循环迭代
        start_time = time.time()#开始时间
        #计算一次,其中feed_dict指示占位符x,y的实际对应数据是谁
        #计算[loss,optimizer,accuracy,merged_summary]这些内容,并一次返回给train_loss,_,acc,summary
        #“_”,返回不获取
        train_loss,_,acc,summary=sess.run([loss,optimizer,accuracy,merged_summary],feed_dict={x:train_x,y:train_y})
        summary_writer.add_summary(summary, current_epoch)#写一次summary,用于可视化
        #打印训练信息
        logging.info("Epoch(time: %.3f s) %d/%d loss=%f  ACC=%.3f)",
                time.time() - start_time,
                current_epoch + 1,
                NUM_EPOCHS,
                train_loss,
                acc)
    
    #最后计算一下预测输出看看
    logging.info(sess.run(y_pred,feed_dict={x:train_x}))
    
    #模型保存
    save_path = saver.save(sess, MODEL_PATH+"1.ckpt")
    logging.info("Model saved in file: %s", save_path)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章