【tensorflow入门】9、结果可视化

1、安装matplotlib  [Python中画图的工具包] ,否则会报错。

Python库导入错误:ImportError: No module named matplotlib.pyplot

   补救:

2、在上一篇代码的基础上添加可视化代码。 

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt  #Python中画图的工具包

#添加神经层的函数
def add_layer(inputs,in_size,out_size,activation_function=None):
    Weights = tf.Variable(tf.random_normal([in_size,out_size]))   #二维
    biases = tf.Variable(tf.zeros([1,out_size])) + 0.1  #1维
    
    Wx_plus_b = tf.matmul(inputs,Weights)+biases
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs

x_data = np.linspace(-1,1,300)[:,np.newaxis]  #300行1列
noise = np.random.normal(0,0.05,x_data.shape)
y_data = np.square(x_data) -0.5 + noise 

xs = tf.placeholder(tf.float32,[None,1])
ys = tf.placeholder(tf.float32,[None,1])

l1 = add_layer(xs,1,10,activation_function=tf.nn.relu)
prediction = add_layer(l1,10,1,activation_function=None)

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)


# fig = plt.figure()  #首先生成一个图片框
# ax = fig.add_subplot(1,1,1)  #要做一个连续性的plot画图需要用这个 #(1,1,1):编号
# ax.scatter(x_data,y_data)    #画散点图   #用点的形式plot上来

# #plt.ion()   #让程序plot后不暂停
# #plt.show()  #打印  真实数据 打印完整个程序暂停了

fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data,y_data)
plt.ion()
plt.show()



for i in range(1000):
    sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
    if i %50 == 0:
        try:
            ax.lines.remove(lines[0])
        except Exception:
            pass
        prediction_value = sess.run(prediction,feed_dict={xs:x_data})
        lines = ax.plot(x_data,prediction_value,'r-',lw=5)
        
        plt.pause(0.1)
        
        
#    if i %50 == 0:
        #print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))  
        
#         try:
#             ax.lines.remove(lines[0])  #抹去lines的第一条线段 (本来就一条,所以就是这条)
        
#         except Exception:
#             pass  #什么都不做
        
#         prediction_value = sess.run(prediction,feed_dict={xs:x_data})
#         lines = ax.plot(x_data,prediction_value,'r-',lw=5)   #用曲线形式,把prediction的值plot上去  红线,线宽为5
#         #如果想连续plot的话,plot出的每条线要先抹去,再plot另外一条线。否则就是叠加。
        
#         plt.pause(0.1)   #暂停0.1秒   
        
       

效果:动态的,这里仅为截图

前提:Python版本需要>3.5,否则应当去掉 plt.ion(),并将plt.show()改为plt.show(block=False)。否则画完散点图后程序暂停,无法继续出现折线图。

3、思路:

1、导入matplotlib包

2、真实数据用散点图表示。(x_data,y_data)

3、预测数据用折线图表示。(x_data,prediction_value)

注意:prediction是个op,prediction_value是个value

4、折线效果每0.1秒擦除并重新画。

 

4、踩坑:

我使用jupyter编写和运行python代码,jupyter默认对动态图无法显示,只能显示静态效果。

我百度到的各种方法均效果不佳,最终将代码粘贴在cmd命令行上,才能通过弹窗实现动态效果。

 

ps:

想换个写python的软件了,毕竟后面要写大数据层次的深度学习,怕jupyter带不起来。

大家有没有好用的软件推荐一下呀(づ ̄3 ̄)づ╭❤~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章