Code:
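# Written against the TensorFlow 1.x API (tf.placeholder, tf.Session)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Sample points: 100 noisy samples of y = 3x + 6
X_dot = np.linspace(-1, 1, 100)
Y_dot = 3 * X_dot + 6 + np.random.randn(*X_dot.shape) * 0.2

# Scalar placeholders: one (x, y) pair is fed per training step
X = tf.placeholder("float")
Y = tf.placeholder("float")
w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="bias")
# Squared error for a single sample
loss = tf.square(Y - X * w - b)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    echo = 1  # epoch counter
    for i in range(10):
        for (x, y) in zip(X_dot, Y_dot):
            _, Loss, W, B = sess.run([train_op, loss, w, b], feed_dict={Y: y, X: x})
        # Loss here is the loss of the last sample fed in this epoch
        print("echo:", echo, " w:", W, " b:", B, " loss:", Loss)
        echo += 1

# Noisy data and the fitted line
plt.plot(X_dot, Y_dot)
plt.plot(X_dot, X_dot * W + B)
plt.show()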
Output:
echo: 1 w: 0.222561 b: 6.07669 loss: 7.05683
echo: 2 w: 1.38609 b: 6.42955 loss: 1.15972
echo: 3 w: 2.11399 b: 6.25391 loss: 0.251627
echo: 4 w: 2.5018 b: 6.11855 loss: 0.0569603
echo: 5 w: 2.70129 b: 6.04374 loss: 0.0118342
echo: 6 w: 2.80303 b: 6.00492 loss: 0.0018697
echo: 7 w: 2.85481 b: 5.98508 loss: 9.94998e-05
echo: 8 w: 2.88114 b: 5.97498 loss: 4.80562e-05
echo: 9 w: 2.89453 b: 5.96984 loss: 0.000241081
echo: 10 w: 2.90134 b: 5.96723 loss: 0.000395928
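As a rough cross-check on the numbers above, the same kind of data can be fit in closed form with np.polyfit. This is only a sketch and not part of the original run: the fixed seed and the names w_ls and b_ls are assumptions, so the values will not match the printed output digit for digit, but they should land near w = 3 and b = 6, the values used to generate the data.

```python
import numpy as np

# Assumption: a fixed seed for reproducibility; the original run was unseeded.
rng = np.random.RandomState(0)

X_dot = np.linspace(-1, 1, 100)
Y_dot = 3 * X_dot + 6 + rng.randn(*X_dot.shape) * 0.2

# Degree-1 least-squares fit; polyfit returns the highest-order coefficient first.
w_ls, b_ls = np.polyfit(X_dot, Y_dot, 1)
print("least squares  w:", w_ls, " b:", b_ls)
```

If the trained w and b are far from these values, the learning rate or the number of epochs is the first thing to look at.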
Pitfalls I ran into:
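1. 'tuple' object cannot be interpreted as an integer
After X_dot is initialized with np.linspace, passing X_dot.shape straight to np.random.randn raises the error above, because randn expects each dimension as a separate integer argument. Printing np.shape(X_dot) gives (100,), which is presumably where the tuple comes from.
So the possible fixes here are:
1) *X_dot.shape 2) X_dot.shape[0] 3) np.shape(X_dot)[0]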
2. tf.placeholder takes its dtype as the string "float". Curious, and just as curious as float("inf").
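3. There are 100 points here and the training batch size is 1, so the points are fed in one at a time with for (x,y) in zip(X_dot,Y_dot), rather than putting X_dot and Y_dot directly into feed_dict.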
4. sess.run([train_op,loss,w,b],feed_dict={Y:y,X:x})
Combining fetches with feed_dict in a single call also makes it easy to print the training progress.
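Pitfall 1 can be reproduced with NumPy alone. A small sketch; the names raised and noise_a through noise_d are made up for illustration:

```python
import numpy as np

X_dot = np.linspace(-1, 1, 100)
print(X_dot.shape)  # prints (100,), a one-element tuple

# randn takes each dimension as its own integer argument, so a tuple fails.
try:
    np.random.randn(X_dot.shape)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# The three fixes from the list above all produce 100 samples.
noise_a = np.random.randn(*X_dot.shape)
noise_b = np.random.randn(X_dot.shape[0])
noise_c = np.random.randn(np.shape(X_dot)[0])

# standard_normal takes the shape tuple directly, so no unpacking is needed.
noise_d = np.random.standard_normal(X_dot.shape)
```

Unpacking with * is the most general of the three fixes, since it keeps working if X_dot later gains more dimensions.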
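To see what the per-sample loop in pitfalls 3 and 4 computes at each step, here is a plain NumPy sketch of the same updates with hand-written gradients. It is not how TensorFlow implements its optimizer internally; the helper name sgd_step and the fixed seed are assumptions made for illustration.

```python
import numpy as np

def sgd_step(w, b, x, y, lr=0.01):
    """One gradient-descent step on loss = (y - x*w - b)**2 for a single sample."""
    err = y - x * w - b
    loss = err ** 2
    # d(loss)/dw = -2*x*err and d(loss)/db = -2*err, so descending adds these terms.
    w = w + lr * 2 * x * err
    b = b + lr * 2 * err
    return w, b, loss

rng = np.random.RandomState(0)  # assumption: seeded for reproducibility
X_dot = np.linspace(-1, 1, 100)
Y_dot = 3 * X_dot + 6 + rng.randn(*X_dot.shape) * 0.2

w, b = 0.0, 0.0
for epoch in range(1, 11):
    for x, y in zip(X_dot, Y_dot):  # batch size 1: one point per update
        w, b, loss = sgd_step(w, b, x, y)
    # loss is the pre-update loss of the last sample in this epoch
    print("epoch:", epoch, " w:", w, " b:", b, " loss:", loss)
```

Returning w, b and loss from every step plays the same role as listing them in fetches: the values needed for printing come back from the same call that performs the update.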