Keras中循環使用K.ctc_decode內存不釋放問題

問題描述

類似問題:

  1. K.ctc_decode() memory not released#6770
  2. K.ctc_decode() memory not released#9011
  3. Keras using Lambda layers error with K.ctc_decode

如下一段代碼,在多次調用了K.ctc_decode時,會發現程序佔用的內存會越來越高,執行速度越來越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
    x, y = next(data)
    _y = model.predict(x)
    shape = _y.shape
    input_length = np.ones(shape[0]) * shape[1]
    ctc_decode = K.ctc_decode(_y, input_length)[0][0]
    out = K.get_value(ctc_decode)

原因

每次執行ctc_decode時都會向計算圖中添加一個節點,這樣會導致計算圖逐漸變大,從而影響計算速度和內存。
PS:有資料說是由於get_value導致的,其中也給出瞭解決方案。
但是我將ctc_decode放在循環體之外就不再出現內存和速度問題,這是否說明get_value影響其實不大呢?

解決方案

  1. 通過K.function封裝K.ctc_decode,只需初始化一次,只向計算圖中添加一個計算節點,然後多次調用該節點(函數)
data = generator(...)
model = init_model(...)
x = base_model.output    # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.output, input_length], [ctc_decode[0][0]])
for i in range(NUM):
    x, y = next(data)
    out = decode([x, np.ones(1)])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章