在tensorflow的使用中,經常會看到tf.add_to_collection()和tf.get_collection()這兩個函數成對的出現,那麼這兩個函數到底有什麼用呢?
用個小例子,解釋一下
import tensorflow as tf
import numpy as np
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1))
tf.add_to_collection('loss', v1)
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print (tf.get_collection('loss'))
print(sess.run(tf.get_collection('loss')))
print(type(sess.run(tf.get_collection('loss'))))
print (sess.run(tf.add_n(tf.get_collection('loss'))))
可以看到代碼中我們創建了兩個tensor,一個初始化爲1,一個初始化爲2,通過使用tf.add_to_collection()把兩個變量加載到一個名爲‘loss’的list裏面,tf.get_collection()可以獲得對應的list裏面存儲的變量,最後通過tf.add_n把‘loss’裏面的變量相加求和。
[<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
[array([1.], dtype=float32), array([2.], dtype=float32)]
<class 'list'>
[3.]
看到代碼中的命名,我們也可以推測一下這個在tf裏面經常可以用來保存多個loss的值。