Tensorflow: tf.add_to_collection()和tf.get_collection()的用法

在tensorflow的使用中,經常會看到tf.add_to_collection()和tf.get_collection()這兩個函數成對的出現,那麼這兩個函數到底有什麼用呢?

用個小例子,解釋一下

import tensorflow as tf
    import numpy as np  
     
    v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1))
    tf.add_to_collection('loss', v1)
    v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
    tf.add_to_collection('loss', v2)
     
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print (tf.get_collection('loss'))
        print(sess.run(tf.get_collection('loss')))
        print(type(sess.run(tf.get_collection('loss'))))
        print (sess.run(tf.add_n(tf.get_collection('loss'))))


可以看到代碼中我們創建了兩個tensor,一個初始化爲1,一個初始化爲2,通過使用tf.add_to_collection()把兩個變量加載到一個名爲‘loss’的list裏面,tf.get_collection()可以獲得對應的list裏面存儲的變量,最後通過tf.add_n把‘loss’裏面的變量相加求和。

[<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
[array([1.], dtype=float32), array([2.], dtype=float32)]
<class 'list'>
[3.]

看到代碼中的命名,我們也可以推測一下這個在tf裏面經常可以用來保存多個loss的值。

總結一下

tf.add_to_collection:把變量放入一個list中,把很多變量變成一個列表

tf.get_collection:從一個列表中取出全部變量,獲得的對象是一個列表

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章