ctc_loss_calculator.cc:144] No valid path found.或loss: inf

最近經常出現一個錯誤,在模型訓練的時候loss:inf,如果出現的不多的話還是可以接受的,但是一旦這個大量出現,模型就不能訓練了,損失也很難收斂,所以今天我終於把這個問題解決了,寫下來表示分享。
經過分析,是輸入長度和標籤長度之間的問題,網上說要求輸入長度要大於標籤長度,我看了一下我的輸入長度13,標籤長度10,符合要求,但是依然出現錯誤,我換了一個模型之後輸入長度14,標籤長度10,問題消失,得出結論,輸入長度要高於標籤長度一部分,至於高出多少,應該考慮識別的字符串中重複並且相鄰的字符數,簡單來說就是儘量的多一些吧,目前沒有分析增加輸入長度對性能的影響,至少肉眼感覺不出來,但是影響性能是肯定的。
下面說怎麼增加輸入長度,我們知道用CTCloss的時候需要有四個輸入,分別如下:
在這裏插入圖片描述
後兩個參數就是輸入長度和標籤長度,標籤長度肯定是沒法改的,這個需求是固定的,所以只能改input_length,後來我發現input_length是和網絡結構相關的,如圖:
在這裏插入圖片描述
可以看到,基礎網絡傳入CTC中的尺寸是(15,37),這個15和我們的input_length就有關係了,我這裏input_length=15-2(我也沒搞懂爲啥減2),37代表的是我識別的36個字符和一個空格,這個時候我只需要增加15這個數就好了,所以我在前邊增加了一個Reshape層,如圖:
在這裏插入圖片描述
經過Reshape之後變成了30,這樣我們再看CTCloss部分:
在這裏插入圖片描述
這裏的輸入尺寸變成了(30,37),input_length變成了28,label依然是10,錯誤沒有了,又可以浪了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章