Tensorflow: tf.add_to_collection()和tf.get_collection()的用法

在tensorflow的使用中,经常会看到tf.add_to_collection()和tf.get_collection()这两个函数成对的出现,那么这两个函数到底有什么用呢?

用个小例子,解释一下

import tensorflow as tf
    import numpy as np  
     
    v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1))
    tf.add_to_collection('loss', v1)
    v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
    tf.add_to_collection('loss', v2)
     
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print (tf.get_collection('loss'))
        print(sess.run(tf.get_collection('loss')))
        print(type(sess.run(tf.get_collection('loss'))))
        print (sess.run(tf.add_n(tf.get_collection('loss'))))


可以看到代码中我们创建了两个tensor,一个初始化为1,一个初始化为2,通过使用tf.add_to_collection()把两个变量加载到一个名为‘loss’的list里面,tf.get_collection()可以获得对应的list里面存储的变量,最后通过tf.add_n把‘loss’里面的变量相加求和。

[<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
[array([1.], dtype=float32), array([2.], dtype=float32)]
<class 'list'>
[3.]

看到代码中的命名,我们也可以推测一下这个在tf里面经常可以用来保存多个loss的值。

总结一下

tf.add_to_collection:把变量放入一个list中,把很多变量变成一个列表

tf.get_collection:从一个列表中取出全部变量,获得的对象是一个列表

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章