【Text Transcriptor】訓練CRNN時,關於ctc_loss的幾點注意事項

這個ctc_loss很魔性,訓練CRNN虐了我幾個來回。

我的數據集圖片大小不一,我是先等比例縮小到固定高度爲32,寬度不定。

常見三個問題:

1.CTC Loss Error: invalidArgumentError: Not Enough time for target transition sequence.

2.CTC Loss Error: InvalidArgumentError: sequence_length(b) <= time

3.ctc_loss error “No valid path found.” (這個錯誤對模型收斂沒有很大影響,只是出錯的那一個batch參數沒有更新優化。如果這個錯誤很少,可以忽略。如果這個錯誤很多的話就建議用下面方法優化一下訓練集。)

導致這三個問題的原因,就是label_length 和input_length的取值問題。

1. CRNN一個主要優點就是可以識別任意長度的圖片。在訓練的時候,先統一將圖片padding到一個固定的很長的寬度。然後input_length設置爲你等比例縮小後,padding之前的圖片的寬除以四。部分代碼如下:

Img = Image.open(imagepath).convert('L')
ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
input_length[i] = ResizedImg.shape[1] // 4

2. label_length很簡單理解,就是ground truth的長度。

3. 如果你以爲這樣就完事大吉可以訓練你就錯了。因爲你的圖片可能有不合格的存在。導致問題3出現,loss變爲inf。

4. 所以在訓練前,應該過濾一遍所有訓練集和驗證集的圖片。ctc_loss在計算預測結果和真值的loss的時候,會在你真值label中重複的字符之間插入空符,所以必須將label_length加上空符個數大於input_length的圖片刪除掉。而代碼中的2,是我考慮有可能在label的開頭和末尾存在空符。(我並沒有驗證這個想法,只是爲了保險起見。)舉個例子,你圖片高度爲32,寬度爲160,那麼input_length=40。label='abbbccddddcccaa',label_length=15,經過計算repreat_number爲2(bbb)+1(cc)+3(dddd)+2(ccc)+1(aa),然後再加上開頭結果的空符數2,最終等於11。也就是說必須滿足label_length(15)+repreat_number(11)<=input_length(40)的圖片纔是合格的圖片。部分代碼如下:

Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
l = [len(list(g)) for k, g in itertools.groupby(Label)]
repeat_number = 0
for n in l:
    if n > 1:
        repeat_number += (n - 1)
input_length = ResizedImg.shape[1] // 4 
if len(Label)+repeat_number+2 > input_length:
    continue

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章