1、安装matplotlib [Python中画图的工具包] ,否则会报错。
Python库导入错误:ImportError: No module named matplotlib.pyplot
补救:
2、在上一篇代码的基础上添加可视化代码。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt #Python中画图的工具包
#添加神经层的函数
def add_layer(inputs,in_size,out_size,activation_function=None):
Weights = tf.Variable(tf.random_normal([in_size,out_size])) #二维
biases = tf.Variable(tf.zeros([1,out_size])) + 0.1 #1维
Wx_plus_b = tf.matmul(inputs,Weights)+biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
x_data = np.linspace(-1,1,300)[:,np.newaxis] #300行1列
noise = np.random.normal(0,0.05,x_data.shape)
y_data = np.square(x_data) -0.5 + noise
xs = tf.placeholder(tf.float32,[None,1])
ys = tf.placeholder(tf.float32,[None,1])
l1 = add_layer(xs,1,10,activation_function=tf.nn.relu)
prediction = add_layer(l1,10,1,activation_function=None)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# fig = plt.figure() #首先生成一个图片框
# ax = fig.add_subplot(1,1,1) #要做一个连续性的plot画图需要用这个 #(1,1,1):编号
# ax.scatter(x_data,y_data) #画散点图 #用点的形式plot上来
# #plt.ion() #让程序plot后不暂停
# #plt.show() #打印 真实数据 打印完整个程序暂停了
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data,y_data)
plt.ion()
plt.show()
for i in range(1000):
sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
if i %50 == 0:
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = sess.run(prediction,feed_dict={xs:x_data})
lines = ax.plot(x_data,prediction_value,'r-',lw=5)
plt.pause(0.1)
# if i %50 == 0:
#print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
# try:
# ax.lines.remove(lines[0]) #抹去lines的第一条线段 (本来就一条,所以就是这条)
# except Exception:
# pass #什么都不做
# prediction_value = sess.run(prediction,feed_dict={xs:x_data})
# lines = ax.plot(x_data,prediction_value,'r-',lw=5) #用曲线形式,把prediction的值plot上去 红线,线宽为5
# #如果想连续plot的话,plot出的每条线要先抹去,再plot另外一条线。否则就是叠加。
# plt.pause(0.1) #暂停0.1秒
效果:动态的,这里仅为截图
前提:Python版本需要>3.5,否则应当去掉 plt.ion(),并将plt.show()改为plt.show(block=False)。否则画完散点图后程序暂停,无法继续出现折线图。
3、思路:
1、导入matplotlib包
2、真实数据用散点图表示。(x_data,y_data)
3、预测数据用折线图表示。(x_data,prediction_value)
注意:prediction是个op,prediction_value是个value
4、折线效果每0.1秒擦除并重新画。
4、踩坑:
我使用jupyter编写和运行python代码,jupyter默认对动态图无法显示,只能显示静态效果。
我百度到的各种方法均效果不佳,最终将代码粘贴在cmd命令行上,才能通过弹窗实现动态效果。
ps:
想换个写python的软件了,毕竟后面要写大数据层次的深度学习,怕jupyter带不起来。
大家有没有好用的软件推荐一下呀(づ ̄3 ̄)づ╭❤~