【TensorFlow動手玩】常用集合: Variable, Summary, 自定義

集合

tensorflow用集合colletion組織不同類別的對象。tf.GraphKeys中包含了所有默認集合的名稱。

collection提供了一種“零存整取”的思路:在任意位置,任意層次都可以創造對象,存入相應collection中;創造完成後,統一從一個collection中取出一類變量,施加相應操作。

例如,tf.Optimizer只優化tf.GraphKeys.TRAINABLE_VARIABLES中的變量。

本文介紹幾個常用集合
- Variable集合:模型參數
- Summary集合:監測
- 自定義集合

Variable

Variable被收集在名爲tf.GraphKeys.VARIABLEScolletion

定義

Tensorflow使用Variable類表達、更新、存儲模型參數。

Variable是在可變更的,具有保持性的內存句柄,存儲着Tensor。必須使用Tensor進行初始化。

k = tf.Variable(tf.random_normal([]), name='k')

創建的Variable被添加到默認的collection中。

初始化

在整個session運行之前,圖中的全部Variable必須被初始化。

sess = tf.Session()
init = tf.initialize_all_variables() 
sess.run(init)

在執行完初始化之後,Variable中的值生成完畢,不會再變化。

特別強調Variable的值在sess.run(init)之後就確定了;Tensor的值要在sess.run(x)之後才確定。

獲取

Tensor, Operation一樣,Variable也是全局的。
可以通過tf.all_variables()查看所有tf.GraphKeys.VARIABLES中的對象:

# example for y = k*x
x = tf.constant(1.0, shape=[])      # 0D tensor
k = tf.Variable(tf.constant(0.5, shape=[]) )
y = tf.mul(x, k)
v = tf.all_variables()

也可以用通用方法直接訪問collection

v = tf.get_collection(tf.GraphKeys.VARIABLES)

各類Variable

另外,tensorflow還維護另外幾個collection

函數 集合名 意義
tf.all_variables() VARIABLES 存儲和讀取checkpoints時,使用其中所有變量
tf.trainable_variables() TRAINABLE_VARIABLES 訓練時,更新其中所有變量
tf.moving_average_variables() MOVING_AVERAGE_VARIABLES ExponentialMovingAverage對象會生成此類變量
tf.local_variables() LOCAL_VARIABLES all_variables()之外,需要用tf.init_local_variables()初始化
tf.model_variables() MODEL_VARIABLES

Summary

Summary被收集在名爲tf.GraphKeys.SUMMARIEScolletion

定義

Summary是對網絡中Tensor取值進行監測的一種Operation。這些操作在圖中是“外圍”操作,不影響數據流本身。

用例

我們模仿常見的訓練過程,創建一個最簡單的用例。

# 迭代的計數器
global_step = tf.Variable(0, trainable=False)
# 迭代的+1操作
increment_op = tf.assign_add(global_step, tf.constant(1))
# 實例應用中,+1操作往往在`tf.train.Optimizer.apply_gradients`內部完成。

# 創建一個根據計數器衰減的Tensor
lr = tf.train.exponential_decay(0.1, global_step, decay_steps=1, decay_rate=0.9, staircase=False)

# 把Tensor添加到觀測中
tf.scalar_summary('learning_rate', lr)

# 並獲取所有監測的操作`sum_opts`
sum_ops = tf.merge_all_summaries()

# 初始化sess
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)  # 在這裏global_step被賦初值

# 指定監測結果輸出目錄
summary_writer = tf.train.SummaryWriter('/tmp/log/', sess.graph)

# 啓動迭代
for step in range(0, 10):
    s_val = sess.run(sum_ops)    # 獲取serialized監測結果:bytes類型的字符串
    summary_writer.add_summary(s_val, global_step=step)   # 寫入文件
    sess.run(increment_op)     # 計數器+1

調用tf.scalar_summary系列函數時,就會向默認的collection中添加一個Operation

再次回顧“零存整取”原則:創建網絡的各個層次都可以添加監測;在添加完所有監測,初始化sess之前,統一用tf.merge_all_summaries獲取。

查看

SummaryWriter文件中存儲的是序列化的結果,需要藉助TensorBoard才能查看。

在命令行中運行tensorboard,傳入存儲SummaryWriter文件的目錄:

tensorboard --logdir /tmp/log

完成後會提示:

You can navigate to http://127.0.1.1:6006

可以直接使用服務器本地瀏覽器訪問這個地址(本機6006端口),或者使用遠程瀏覽器訪問服務器ip地址的6006端口。

自定義

除了默認的集合,我們也可以自己創造collection組織對象。網絡損失就是一類適宜對象。

tensorflow中的Loss提供了許多創建損失Tensor的方式。

x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)

x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)

創建損失不會自動添加到集合中,需要手工指定一個collection

tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)

創建完成後,可以統一獲取所有損失,losses是個Tensor類型的list:

losses = tf.get_collection('losses')

另一種常見操作把所有損失累加起來得到一個Tensor

loss_total = tf.add_n(losses)

執行操作可以得到損失取值:

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)

實際上,如果使用TF-Slim包的losses系列函數創建損失,會自動添加到名爲”losses”的collection中。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章