tensorflow tf.assign 和 = + 區別

首先注意assign在tensorflow的操作是將改變當前節點的值,並將改變值後的節點返回,這是tensorflow的api。而等於號是python裏的賦值語句,但與普通python賦值語句不同的是,由於通常tensorflow建圖時右邊的操作都是新建一個節點,所以這個等於號其實就是將變量的引用到這個新節點上。你只要分清哪些是tensorflow中的操作和哪些是python語言的操作,就能分清哪些是在建圖,哪些只是在改變引用。

import tensorflow as tf
with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  assign_op = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(assign_op)
    print a.eval()
    print a.eval()

2
2
2

第一種,assign_op對a進行了賦值,因此sess.run(assign_op)運行後返回值的是2,由於assign是對原始節點修改值,而不是新建節點。因此a對應節點的值也變成了2。

import tensorflow as tf
with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = a + 1
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   print sess.run(a)
   print a.eval()
   print a.eval()
   
2
2
2

第二種,注意a = a + 1的實際操作是首先將右邊a的節點加上1,這是一個新建節點操作,a+1返回的是這個新建節點,此時a = 新建節點。也就是a引用的節點地址變了。這是run(a)、a.eval()其實都不會修改原始節點name="a"的值,也就是始終爲1,所以run(a)、a.eval()執行的都是1+1運算。

import tensorflow as tf
with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(a)
    print a.eval()
    print a.eval()
    
2
3
4

第三種,首先新建個節點name=“a”,再對這個節點採用assign操作,注意這個assign的操作被a給引用了。也就是每次運行run(a),a.eval()都是需要assign運行一次。這樣,assign運行第一次,name="a"節點值變爲2;assign運行第二次相當於2+1=3,同理第三次3+1。

import tensorflow as tf

counter = tf.Variable(0, name="counter")
one = tf.constant(1)
ten = tf.constant(10)
new_counter = tf.add(counter, one)  # tf.add 相當於counter+one
assign = tf.assign(counter, new_counter)
result = tf.add(assign, ten)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  for _ in range(3):
    print sess.run(counter)
    print sess.run(result)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章