在tensorflow的使用中,经常会看到tf.add_to_collection()和tf.get_collection()这两个函数成对的出现,那么这两个函数到底有什么用呢?
用个小例子,解释一下
import tensorflow as tf
import numpy as np
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1))
tf.add_to_collection('loss', v1)
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print (tf.get_collection('loss'))
print(sess.run(tf.get_collection('loss')))
print(type(sess.run(tf.get_collection('loss'))))
print (sess.run(tf.add_n(tf.get_collection('loss'))))
可以看到代码中我们创建了两个tensor,一个初始化为1,一个初始化为2,通过使用tf.add_to_collection()把两个变量加载到一个名为‘loss’的list里面,tf.get_collection()可以获得对应的list里面存储的变量,最后通过tf.add_n把‘loss’里面的变量相加求和。
[<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
[array([1.], dtype=float32), array([2.], dtype=float32)]
<class 'list'>
[3.]
看到代码中的命名,我们也可以推测一下这个在tf里面经常可以用来保存多个loss的值。