TensorFlow出現Found Inf or NaN global norm的排查和解決辦法

在訓練神經網絡的時候,由於一些原因會出現NaN或者Inf,致使訓練終止。在查閱相關資料之後,並且結合我出現的問題,做了一些總結。出現的代碼在TensorFlow 1.12.2版本可正常執行。

出現問題的原因

出現NaN或者Inf的原因一般可分爲以下三種

  1. 輸入數據有錯
  2. 出現了運算錯誤,如除數爲零,log0等
  3. 梯度爆炸

輸入數據有錯

訓練數據可能包含髒數據,在數據清洗時沒有清洗乾淨,導致錯誤數據輸入進模型。首先可以在輸入模型前,使用np.any(np.isnan(data))來判斷數據是否由nan。若沒有,要考慮到數據的實際約束,如在我查閱資料時,看到有人輸入數據包含最大值和最小值,但是錯誤數據的最大值和最小值反了,導致模型訓練出錯。這一部分要根據具體情況進行具體排查。

運算錯誤

檢查模型中除法的分母是否爲0,如果有0在結合實際情況進行修改。如果損失函數用到交叉熵,或者取log,也要注意0是否出現。如果判斷出有0了,可以使用tf.clip_by_value對值進行限制。

梯度爆炸

常見於模型設計不好,或者模型本身的原因。如RNN易發生梯度爆炸,更換爲LSTM可解決問題。或者模型採用了較大的學習速率,導致更新網絡參數出現問題。

排查建議

最簡單最方便的方法,首先調小學習速率,看看是否是由較高學習速率導致的。可以選擇將學習速率降低一半, 或者降低一個數量級。在多次嘗試之後,若不能解決問題,考慮其他情況。

檢查運算錯誤,主要是有除法運算和取log的地方。判斷是否有0出現,以及是否有0導致的問題,試着使用clip_by_value對數值進行限制。

最後,若還沒有解決問題,再檢查數據是否清洗乾淨。

另外,可以使用一些代碼來輔助檢查。在模型增加以下函數
在這裏插入圖片描述
如果懷疑模型在某個結點node出現了nan,則在模型中使用add_summary_var(node,name)。並且在模型的最後寫上
在這裏插入圖片描述
同時在sess.run部分加入self.merged_inside。這樣一旦出現NaN錯誤,TensorFlow會提醒出現問題的結點名稱(名稱即爲上面函數中指定的name)。如果添加了多個,則會在第一次出現nan的地方報錯,這樣可以幫助我們找到問題所在。

本人在使用這種方法以後,發現在經過RNN之後出現了問題,判斷可能是由於RNN導致的梯度爆炸。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章