tensorflow:不要在session中定義運算

最近在做項目時,總是會有程序崩潰的問題,系統也沒有任何提示。最後通過監控系統發現是內存溢出造成的。

追查下去,發現一段類似這樣的代碼,在session中調用tensorflow的api進行運算:

import tensorflow as tf
X = tf.constant([[1,2,3], [3,2,4]], dtype=tf.float32)
W = tf.constant([[1,1],[2,2],[3,3]], dtype=tf.float32)
bias = tf.constant([1, 2], dtype=tf.float32)
y = tf.nn.softmax(tf.matmul(X, W) + bias)

with tf.Session() as sess:

    for i in range(10):
        print(i)
        sess.run(tf.nn.softmax(tf.matmul(X, W) + bias))

    writer = tf.compat.v1.summary.FileWriter("./graph", sess.graph)
    writer.close()

使用tensorboard查看內存泄漏的原因:

將計算圖展開爲

當然,這裏只是展開了softmax,其他節點也可以類似展開。

可以看到,在session中定義計算節點,存在一個很大的風險,就是會在計算圖中產生新的圖節點,如果像我這樣使用for循環運算,那麼節點數會無限增加,注意不僅僅是softmax節點在增加,其他計算節點也在增加,這樣的開銷會越來越大,直至程序崩潰。

爲了解決這個問題,我們應該使用上面定義的y的等式,在進入session前就已經將計算圖定義好,在session中直接調用,而不是重新搭建。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章