Tensorflow中的Lazy load問題

問題描述

用tensorflow訓練或者inference模型的時候,有時候會遇到運行越來越慢,最終內存被佔滿,導致電腦死機的問題,我們稱之爲內存溢出。出現這種問題很可能是因爲在一個session中,graph循環建立重複的節點所導致的Lazy load問題。

舉例說明

舉個例子,用tensorflow循環做多次加法運算,常見的做法是:

x = tf.Variable(10, name='x')
y = tf.Variable(20, name='y')
z = tf.add(x, y)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(z)

在session開始之前graph如下:
1

有可能有人想省點力氣,把z = tf.add(x, y) 加法操作直接寫到session中:

x = tf.Variable(10, name='x')
y = tf.Variable(20, name='y')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(tf.add(x, y)) # 只在需要時才創建節點

在session開始之前graph如下:
2

可以發現在graph中看不到加法操作了,是不是真的簡化了呢?對單步操作可能是簡化了,但是當遇到如本例中的循環操作時,這種將tensorflow op寫到循環中的做法會產生Lazy load的問題,讓你的內存逐漸被佔滿,爲什麼?

原因解釋

循環結束之後我們將兩種操作方法各自graph中的節點通過print (tf.get_default_graph().as_graph_def())
命令打印出來,正常操作的graph protobuf如下:

node {
 name: "Add"
 op: "Add"
 input: "x/read"
 input: "y/read"
 attr {
 key: "T"
 value {
 type: DT_INT32
  }
 }
}

可見graph中只有一個加法操作的節點,而第二種將加法寫在session循環中的方法用同樣的命令將其graph protobuf打印出來,得到的是重複的10個上述節點。所以,如果循環不是10個,而是更多個,則會導致graph中的節點越來越多,最終導致內存溢出。

解決方法

將一些算術操作通過tf.placeholder寫在session外,session中通過feed_dict={x: X}填入數據。session裏如果有循環,切記不要在循環中進行tf操作。

Reference

  1. Stanford CS20si
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章