首先注意assign在tensorflow的操作是將改變當前節點的值,並將改變值後的節點返回,這是tensorflow的api。而等於號是python裏的賦值語句,但與普通python賦值語句不同的是,由於通常tensorflow建圖時右邊的操作都是新建一個節點,所以這個等於號其實就是將變量的引用到這個新節點上。你只要分清哪些是tensorflow中的操作和哪些是python語言的操作,就能分清哪些是在建圖,哪些只是在改變引用。
import tensorflow as tf
with tf.Graph().as_default():
a = tf.Variable(1, name="a")
assign_op = tf.assign(a, tf.add(a,1))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(assign_op)
print a.eval()
print a.eval()
2
2
2
第一種,assign_op對a進行了賦值,因此sess.run(assign_op)運行後返回值的是2,由於assign是對原始節點修改值,而不是新建節點。因此a對應節點的值也變成了2。
import tensorflow as tf
with tf.Graph().as_default():
a = tf.Variable(1, name="a")
a = a + 1
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(a)
print a.eval()
print a.eval()
2
2
2
第二種,注意a = a + 1的實際操作是首先將右邊a的節點加上1,這是一個新建節點操作,a+1返回的是這個新建節點,此時a = 新建節點。也就是a引用的節點地址變了。這是run(a)、a.eval()其實都不會修改原始節點name="a"的值,也就是始終爲1,所以run(a)、a.eval()執行的都是1+1運算。
import tensorflow as tf
with tf.Graph().as_default():
a = tf.Variable(1, name="a")
a = tf.assign(a, tf.add(a,1))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(a)
print a.eval()
print a.eval()
2
3
4
第三種,首先新建個節點name=“a”,再對這個節點採用assign操作,注意這個assign的操作被a給引用了。也就是每次運行run(a),a.eval()都是需要assign運行一次。這樣,assign運行第一次,name="a"節點值變爲2;assign運行第二次相當於2+1=3,同理第三次3+1。
import tensorflow as tf
counter = tf.Variable(0, name="counter")
one = tf.constant(1)
ten = tf.constant(10)
new_counter = tf.add(counter, one) # tf.add 相當於counter+one
assign = tf.assign(counter, new_counter)
result = tf.add(assign, ten)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
for _ in range(3):
print sess.run(counter)
print sess.run(result)