CTC Loss和Focal CTC Loss

最近一直在做手寫體識別的工作,其中有個很重要的loss那就是ctc loss,之前在文檔識別與分析課程中學習過,但是時間久遠,早已忘得一乾二淨,現在重新整理記錄下

本文大量引用了- CTC Algorithm Explained Part 1:Training the Network(CTC算法詳解之訓練篇),只是用自己的語言理解了一下,原論文:Connectionist Temporal Classification: Labelling UnsegmSequence Data with Recurrent Neural Networ

解決的問題

套用知乎上的一句話,CTC Loss要解決的問題就是當label長度小於模型輸出長度時,如何做損失函數。
一般做分類時,已有的softmax loss都是模型輸出長度和label長度相同且嚴格對齊,而語音識別或者手寫體識別中,無法預知一句話或者一張圖應該輸出多長的文字,這時做法有兩種:seq2seq+attention機制,不限制輸出長度,在最後加一個結束符號,讓模型自動和gt label對齊;另一種是給定一個模型輸出的最大長度,但是這些輸出並沒有對齊的label怎麼辦呢,這時就需要CTC loss了。
在這裏插入圖片描述

輸出序列的擴展

在這裏插入圖片描述

所以,如果要計算?(?│?),可以累加其對應的全部輸出序列o (也即映射到最終label的“路徑”)的概率即可,如下圖。

在這裏插入圖片描述

前向和後向計算

由於我們沒有每個時刻輸出對應的label,因此CTC使用最大似然進行訓練(CTC 假設輸出的概率是(相對於輸入)條件獨立的)
給定輸入xx,輸出序列 oo 的條件概率是:
p(πx)=yπtt,πLT p(\pi|x) = \prod y^t_{\pi_t}, \forall \pi \in L^{\prime T}
πt\pi _t 是序列 oo 中的一個元素,yy爲模型在所有時刻輸出各個字符的概率,shape爲T*C(T是時刻,提前已固定。C是字符類別數,所有字符+blank(不是空格,是空) ,yπtty^t_{\pi_t} 是模型t時刻輸出爲πt\pi _t的概率

我們模型的目標就是給定輸入x,使得能映射到最終label的所有輸出序列o的條件概率之和最大,該條件概率就是p(πx)p(\pi|x),和模型的輸出概率yy直接關聯

那麼我們如何計算這些條件概率之和呢?首先想到的就是暴力算法,一一找到可以映射到最終label的所有輸出序列,然後概率連乘最後相加,但是很耗時,有木有更快的做法?聯繫一下HMM模型中的前向和後向算法,它就是利用動態規劃求某個序列出現的概率,和此處我們要計算某個輸出序列的條件概率很相似
比如HMM模型中,我們要求紅白紅出現的概率,我們就可以利用動態規劃的思想,因爲紅白紅包含子問題紅白的產生,紅白包含子問題紅的產生,參考引用的圖片。
而這裏我們以apple這個label都可以由哪些輸出序列映射過去爲例(T爲8):
其中的一種 _ _ a p _ p l e
在這裏插入圖片描述
當然其他也可以如 a p p _ p p l e,但是考慮到我們最終對輸出序列的處理(兩個空字符之間的重複元素會去除,字符是從左到右的,且是依次的),我們的路徑(狀態轉移)不是隨便的,根據這樣的規則,我們可以找到所有可以映射到apple的輸出序列

在這裏插入圖片描述
很明顯可以看到這和HMM很像,包含很多相同子問題,可以用動態規劃做

定義在時刻t經過節點s的全部前綴子路徑的概率總和爲前向概率 αt(s)\alpha_t (s),如α3(4)\alpha_3 (4)爲在時刻3所有經過第4個節點的全部前綴子路徑的概率總和: α3(4)\alpha_3 (4) = p(_ap) + p(aap) + p(a_p) + p(app),該節點爲p
在這裏插入圖片描述

類似的定義在時刻t經過節點s的全部後綴子路徑的概率總和爲前向概率 βt(s)\beta_t (s),如β6(8)\beta_6 (8)爲在時刻6所有經過第8個節點的全部後綴子路徑的概率總和: β3(4)\beta_3 (4) = p(lle) + p(l_e) + p(lee) + p(le_),該節點爲l
在這裏插入圖片描述

總結

在這裏插入圖片描述

Focal CTC Loss

在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
實現
在這裏插入圖片描述

參考論文 Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章