問題描述
用tensorflow訓練或者inference模型的時候,有時候會遇到運行越來越慢,最終內存被佔滿,導致電腦死機的問題,我們稱之爲內存溢出。出現這種問題很可能是因爲在一個session中,graph循環建立重複的節點所導致的Lazy load問題。
舉例說明
舉個例子,用tensorflow循環做多次加法運算,常見的做法是:
x = tf.Variable(10, name='x')
y = tf.Variable(20, name='y')
z = tf.add(x, y)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(10):
sess.run(z)
在session開始之前graph如下:
有可能有人想省點力氣,把z = tf.add(x, y)
加法操作直接寫到session中:
x = tf.Variable(10, name='x')
y = tf.Variable(20, name='y')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(10):
sess.run(tf.add(x, y)) # 只在需要時才創建節點
在session開始之前graph如下:
可以發現在graph中看不到加法操作了,是不是真的簡化了呢?對單步操作可能是簡化了,但是當遇到如本例中的循環操作時,這種將tensorflow op寫到循環中的做法會產生Lazy load的問題,讓你的內存逐漸被佔滿,爲什麼?
原因解釋
循環結束之後我們將兩種操作方法各自graph中的節點通過print (tf.get_default_graph().as_graph_def())
命令打印出來,正常操作的graph protobuf如下:
node {
name: "Add"
op: "Add"
input: "x/read"
input: "y/read"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
可見graph中只有一個加法操作的節點,而第二種將加法寫在session循環中的方法用同樣的命令將其graph protobuf打印出來,得到的是重複的10個上述節點。所以,如果循環不是10個,而是更多個,則會導致graph中的節點越來越多,最終導致內存溢出。
解決方法
將一些算術操作通過tf.placeholder
寫在session外,session中通過feed_dict={x: X}
填入數據。session裏如果有循環,切記不要在循環中進行tf操作。