用於序列識別的聚合交叉熵
摘要
在本文中,我們從全新的角度提出了一種新的聚合交叉熵(ACE)方法,用於序列識別。 ACE損失函數表現出對CTC和注意機制的競爭性能,實現快得多(因爲它只涉及四個基本公式),更快的推理\反向傳播(大約O(1)並行),更少的存儲要求(沒有參數且可忽略運行時內存),方便使用(用ACE代替CTC)。 此外,所提出的ACE損失函數具有兩個值得注意的特性:(1)它可以通過將2D預測展平爲1D預測作爲輸入直接應用於2D預測。(2)它僅需要字符及其序列標註中的數字監督,它允許它超越序列識別,例如計數問題。該代碼可在 https://github.com/summerlvsong/Aggregation-CrossEntropy 上公開獲取。
1. 介紹
序列識別或序列標記是將從固定字母表中抽取的標籤序列分配給輸入數據序列,例如語音識別,場景文本識別和手寫文本識別,如圖1所示。最近在深度學習和新的結構上的提高使得系統架構可以處理一維(1D)和二維(2D)的預測問題。對於1D預測問題,網絡的最頂部特徵圖在垂直維度上摺疊以生成1D預測,因爲原始圖像中的字符通常是順序分佈的。典型的例子是常規場景文本識別,在線/離線手寫文本識別和語音識別。對於2D預測問題,輸入圖像中的字符分佈在特定的空間結構中。例如,在數學表達式識別中,相鄰字符之間存在高度複雜的空間關係。在段落級文本識別中,字符通常逐行分佈,而在不規則場景文本識別中,它們通常分佈在側視或曲角模式中。
對於序列識別問題,傳統方法通常需要爲輸入序列中的每個片段或時間步驟分離訓練目標,導致麻煩的預分割和後處理階段 。最近出現的CTC 和注意機制通過規避輸入圖像與其相應標籤序列之間的先前對齊,顯着緩解了這種順序訓練問題。然而,儘管基於 CTC 的網絡在一維預測問題中表現出了顯着的性能,但 基礎方法是複雜的 ;而且, 它的前向後向算法實現很複雜,導致大量的計算消耗 。此外, CTC很難應用於2D預測問題。 同時,注意機制依賴於其注意模塊來進行標籤對齊,從而導致額外的存儲需求和計算消耗。 正如Bahdanau等人所指出的那樣。 識別模型很難通過注意機制從頭開始學習,因爲ground truth 字符串與注意力預測之間存在不一致,特別是對於較長的輸入序列。Bai等人也認爲錯位問題可能會混淆和誤導訓練過程,從而使訓練成本高昂並降低識別準確度。儘管注意機制可以適用於2D預測問題,但在內存和時間消耗方面卻令人望而卻步,
通過上述觀察,我們爲序列識別問題提出了一種新的聚合交叉熵(ACE)損失函數,詳見圖2。 鑑於網絡的預測,ACE損失包括三個簡單階段: (1)沿時間維度聚合每個類別的概率; (2)將累積結果和標籤標註標準化爲所有類別的概率分佈; (3)使用交叉熵比較這兩個概率分佈。 所提出的ACE損失函數的優點可歸納如下:
- 由於其簡單性,ACE損失功能實現得更快(四個基本公式),更快地推斷和反向傳播(大約 並行),更少的內存要求(沒有參數和基本運行時內存),以及與CTC和注意機制相比,使用方便(簡單地用ACE替換CTC)。在表5,3.4部分和4.4部分說明。
- 儘管簡單,ACE損失功能實現了對CTC和注意機制的競爭性能,正如在常規不規則場景文本識別和手寫文本識別問題的實驗中所建立的那樣。
- 通過將2D預測平坦化爲1D預測,ACE損失函數可以適應2D預測問題,如在不規則場景文本識別和計數問題的實驗中所驗證的。
- ACE損失函數不需要用於監督的實例順序信息,這使得它能夠超越序列識別,例如計數問題。
2. 相關工作
2.1 CTC
流行的CTC損失的優勢首先在語音識別和在線手寫文本識別中得到證實。 最近,提出了一種集成的CNN-LSTM-CTC模型來解決場景文本識別問題。還有一些方法旨在擴展應用中的CTC; 例如,張等人。 提出了一種適用於CTC的擴展CTC(ECTC)目標函數,即使只有字級註釋可用,也可以訓練基於RNN的音素識別器。黃等人開發了一種基於期望最大化的在線CTC算法,該算法允許RNN以無限長的輸入序列進行訓練,無需預分割或外部重置。然而,CTC的計算過程非常複雜且耗時,並且當應用於2D問題時需要大量工作來重新排列feature map 和標註。
2.2 注意力機制
注意機制首先在機器翻譯中提出,以使模型能夠自動搜索源句的部分以進行預測。然後,該方法迅速在諸如(視覺)問答,圖像標題生成,語音識別和場景文本識別等應用中變得流行。最重要的是,注意力機制也可以應用於2D預測,例如數學表達式識別和段落識別。 然而,注意機制依賴於複雜的注意力模塊來實現其功能,從而產生額外的網絡參數和運行時間。此外,缺失或多餘的字符很容易導致錯位問題,混淆和誤導訓練過程,從而降低識別準確度。
3. 聚合交叉熵(ACE)
形式上,給定來自訓練集 的輸入圖像 及其序列標註 ,序列識別問題的 一般損失函數評估 在模型參數 下以圖像 爲條件的長度 的標註 的概率如下 :
其中 代表預測序列的第 個位置的預測字符 的概率。因此,這個問題是估計基於模型預測 的一般損失函數,其中 , 是字符集合, 是空格。然而,在CTC損失函數出現之前,直接估計 非常具有挑戰性的。CTC損失函數使用前向後向算法優雅地計算 ,這消除了對預分割數據和外部後處理的需要。注意機制通過基於其注意模塊直接預測 來提供估計一般損失函數的替代解決方案。然而,CTC的前向 - 後向算法是高度複雜和耗時的,而注意機制需要額外的複雜網絡以確保注意力預測和標註之間的對齊。
(左)通常,1D和2D預測分別由集成的CNN-LSTM和FCN模型生成。 對於ACE損失函數,2D預測進一步展平爲1D預測 。在聚合期間,所有時間點的1D預測都是爲每個類獨立累積的,根據 。在歸一化之後,將預測 與GroundTruth 一起用於基於交叉熵的損失估計。(右)一個簡單的例子表明ACE損失函數的標籤的生成。 表明在 cocacola中有兩個 “a”。
在本文中,我們提出ACE損失函數來估計基於模型預測 的一般損失函數。在公式(1)中,通過最大化序列標註的每個位置處的預測(例如 ),可以最小化一般損失函數。然而,基於 直接計算 是具有挑戰性的,因爲註釋中的第 個字符與模型預測 之間的對齊不清楚。因此,不是精確地估計概率 ,而是通過僅監督每個類的累積概率來簡化問題;不考慮標註的序列順序問題。例如,如果一個類在標註中出現兩次,我們要求它在 時間步長上的累積預測概率恰好是兩個,預計它的兩個相應的預測接近於一個。因此,我們可以通過 要求網絡精確預測標註中每個類的字符數來最小化一般損失函數 ,如下所示:
其中 代表字符 在序列標註 中出現的次數。請注意,這個新的損失函數不需要字符順序信息,只需要用於監督的類及其數量。
3.1 基於迴歸的ACE損失函數
現在,問題是將模型預測 連接到每個類的數字預測。我們提出通過把 時間步長的第 個字符的概率求和來計算每個類 的數量,例如 。如圖2中的聚合所示。注意:
因此,我們從迴歸問題的角度調整損失函數(公式(2))如下:
另請注意,預計總共 預測會產生零排放。 因此,我們有 。
爲了找到每個例子 的梯度,我們首先根據網絡輸出 區分 :
其中 。 回想一下,對於 函數,我們有:
其中,如果 ,則 ,否則爲 。現在,我們可以將損失函數與 區分開來,通過輸出層反向傳播梯度:
3.1 梯度消失問題
從公式(7),我們觀察到基於迴歸的ACE損失(公式(4))在反向傳播方面不方便。在早期訓練階段,我們有 。 因此對於大的詞彙序列識別問題, 可以忽略不計,其中 是大的(例如HCTR問題中的7357)。雖然公式(7)中的其他項具有可接受的反向傳播速度,但是梯度將通過 和 縮放到非常小的尺寸,導致梯度消失問題。
3.2 基於交叉熵的ACE損失函數
爲了防止梯度消失問題,有必要抵消由公式(7)中的 函數引入的項 的影響。我們從信息理論中借用交叉熵的概念,信息理論旨在測量兩個概率分佈之間的“距離”。因此,我們把第 個字符 的累計概率標準化爲 ,把字符數量 標準化爲 。然後,在 和 之間的交叉熵可以表示爲:
關於 激活函數之前的 的損失函數導數具有以下形式:
3.2.1 討論
接下來,我們討論更新後的損失函數怎麼解決梯度消失的問題:
- 在早期訓練階段, 在所有時間步長具有大致相同的數量級。因此,標準化後的累計概率 也有和 一致的數量級。即, ;因此通過第 類的梯度就是 。因此,梯度可以通過出現在序列標註 中的字符直接向後傳播到 。除此之外,當 時,例如 ;相應的梯度大約爲 ,這將鼓勵模型作出更大的預測 ,而不在 中出現的字符變小。這是我們的初衷。
- 在後期訓練階段,在後面的訓練階段,只有少數預測 會非常大,而其他預測則足夠小,可以省略。在這種情況下,預測 將佔據 的大部分,就會有 。因此,當 時,梯度可以直接反向傳播到識別網絡。
3.3 二維預測
在諸如具有圖像級標註的不規則場景文本識別的2D預測問題中,定義字符之間的空間關係是具有挑戰性的。字符可能是多條線排列,沿彎曲或傾斜方向排列,或者甚至以隨機方式分佈。幸運的是,提出的ACE損失函數可以推廣到2D預測問題,因爲它不需要序列學習過程的字符順序信息。
假設,輸出的2D預測 高 ,寬 ,那麼在第 行和第 列的預測可以表示爲 。需要計算 和 來進行邊界調整, ,。然後,2D預測的損失函數可以表示爲:
在我們試驗中,我們直接把2D預測 展平爲1D預測 ,其中 ,然後使用公式(8)計算最終損失。
3.4 實現和複雜性分析
實現 如圖2所示, 代表 損失函數的標註;這裏, 代表在序列標註 中字符 出現的次數。在圖2中描述了序列標註 cocacola轉換爲 的標註這麼一個簡單的例子。總的來說,給定模型預測 和它的標註 ,基於交叉熵的 損失函數的關鍵實現包括四個基本的公式:
- 通過對全部時間的第 類的概率求和,計算每一個類的字符數量。
- 標準化累加的概率。
- 標準化標註。
- 估計在 和 之間的交叉熵。
在實際工作中,模型預測 通常是通過集成的 CNN-LSTM模型(1D預測)或者FCN模型(展平的2D預測)提供的。也就是說,ACE的輸入假設與CTC的輸入假設相同。因此,提出的ACE可以很方便的通過代替框架中的CTC層來應用。
複雜性分析 ACE損失函數的總體計算是基於上述的四個公式實現的,分別有 的計算複雜度。因此,ACE損失函數的計算複雜度是 。但請注意,這四個公式中的逐元素乘法,除法和 log 操作可以在 GPU 以 複雜度並行實現。相反,基於前向後向算法的CTC的實現有一個 的複雜度。因爲CTC的前向變量 和後向變量 依賴於之前的結果(例如 和 )計算當前的輸出,CTC很難在時間維度上並行加速。更多的,CTC的基本操作也是非常複雜的,導致總體消耗時間大於ACE。關於注意力機制,其計算複雜度與“注意力”的時間成比例。然而,每次注意模塊的計算複雜度已經具有與CTC相似的量級。
從內存消耗的角度來看,提出的ACE損失函數幾乎不需要內存消耗,因爲可以根據四個基本公式直接計算ACE損失結果。 但是,CTC需要額外的空間來保存與時間步長 和序列標註長度成比例的前向後向變量。同時,注意機制需要額外的模塊來實現“注意力”。 因此,其內存消耗量明顯大於CTC和ACE。
總的來看,與CTC和注意力比較,提出的ACE損失函數在計算複雜性和內存需求方面都表現出顯着的優勢。
4. 性能評估
在我們的實驗中,使用了三個任務來評估所提出的ACE損失函數的有效性,包括場景文本識別,離線手寫中文文本識別和計算日常場景中的物體。對於這些任務,我們估計了1D和2D預測的ACE損失,其中1D表示最終預測是T 預測序列,2D表示最終特徵圖具有H×W的2D預測。
4.1 場景文本識別
由於背景,外觀,分辨率,文本字體和顏色的大變化,場景文本識別經常遇到問題,使其成爲具有挑戰性的研究課題。在本節中,我們通過利用此任務中豐富性和多樣性的測試基準來研究場景文本識別的一維和二維預測。
4.1.1 數據集
在場景文本識別使用了兩種數據集:常規的文本數據集,例如 IIIT5K-Words,Street View Text,ICDAR2003和ICDAR2013 和不規則的文本數據集,例如SVT-Perspective,CUTE80和ICDAR2015。正規的數據集使用ACE損失函數的1D預測來評估,不規則的文本數據集使用2D預測來評估。
IIIT5K-Words(IIIT5K)包含3000張剪切的單詞圖像用來測試。
Street View Text(SVT)從 Google Street View收集來的,包含647張單詞圖像。其中許多都被噪音和模糊嚴重破壞,或者分辨率非常低。
ICDAR2003(IC03)包含251張場景圖像,使用文本邊框標註。包含867張裁剪圖像。
ICDAR2013(IC13)繼承了IC03中的大部分樣本。包含1015個裁剪文本圖像。
SVT-Perspective(SVT-P)包含639張裁剪圖像用來測試,是從Google街景中的側視角快照中選擇來的。 因此,大多數圖像都是透視扭曲的。 每個圖像都與50個單詞的詞典和一個完整的詞典相關聯。
CUTE80(CUTE)包含80幅自然場景拍攝的高分辨率圖像。 它專門用於彎曲文本識別。 該數據集包含288個裁剪的自然圖像用於測試。 沒有詞典相關聯。
ICDAR 2015(IC15)包含2077張剪裁圖像,包括超過200張不規則文本。
4.1.2 實現細節
對於常規數據集上的1D序列識別,我們的實驗基於CRNN網絡,僅對Jaderberg等人發佈的800萬個合成數據進行了訓練。對於不規則數據集上的2D序列識別,我們的實驗基於ResNet-101,其中conv1變爲3×3,步長爲1,conv4_x作爲輸出。訓練數據集由Jaderberg等人發佈的800萬個合成數據和從8萬張圖像中裁剪出來的4萬個合成實例(不包括包含非字母數字字符的圖像)組成。輸入圖像標準化爲 (96, 100),最終的2D預測爲 (12,13),如圖5所示。爲了解碼2D預測,我們通過按照從左到右和從上到下的順序連接每個列來平坦化2D預測,然後按照一般過程對平坦的1D預測進行解碼。
在我們的實驗中,我們觀察到直接將輸入圖像標準化爲(96,100)的大小會使網絡訓練過程過載。因此,我們訓練另一個網絡來預測文本圖像中的字符編號,並相對於字符編號對文本圖像進行標準化,以使字符大小保持在可接受的限度內。
4.1.3 實驗結果
爲了研究迴歸和交叉熵對ACE損失函數的作用,我們使用常規場景文本數據集進行了一維預測的實驗,詳見表1和圖3。因爲在場景文本識別中只有37個類,所以在公式(7)中的項 的負面影響沒有HCTR問題那麼嚴重(7357類)。如圖3所示,基於迴歸的ACE損失,網絡可以收斂但速度很慢,可能是由於梯度消失問題。基於交叉熵的ACE損失,WER和CER在早期訓練階段以相對較高的速率和更平滑的方式發展,並在隨後的訓練階段獲得明顯更好的收斂結果。表1清楚地揭示了基於交叉熵的ACE損失函數優於基於迴歸的ACE損失函數的優越性。因此,我們對所有剩餘的實驗使用基於交叉熵的ACE損失函數。此外,使用相同的網絡設置(CRNN)和訓練集(800萬個合成數據),所提出的ACE損失函數表現出與CTC以前的工作相當的性能。
爲了驗證所提出的ACE損失與字符順序的獨立性,我們使用ACE,CTC和注意力在四個數據集上進行實驗;標註的字符順序以不同的比例隨機打亂,如圖4所示。可以發現,注意力和CTC的性能隨着打亂比例的增加在下降。特別地,因爲錯誤問題很容易誤導訓練過程中的注意力,所以注意力比CTC更敏感,相比之下,所提出的ACE損失函數對於打亂比例的所有設置都表現出類似的識別結果,這是因爲它只需要類別及其數量進行監督,完全省略字符順序信息。
對於不規則的場景文本,我們用2D預測進行了文本識別實驗。在表2中,我們提供了和以前的方法的比較,這些方法僅考慮識別模型,沒有進行公平比較的整改。如表2所示,所提出的ACE損失函數在數據集CUTE和IC15上表現出優異的性能,特別是在CUTE上,絕對誤差減少5.8%。這是因爲數據集CUTE是專門用於彎曲文本識別的,因此,充分展示了ACE損失功能的優勢。對於數據集SVT-P,我們的解碼結果不如Yang等人的有效。這是因爲數據集SVT-P中的大量圖像具有非常低的分辨率,這對語義上下文建模產生了非常高的要求。但是,我們的網絡僅基於CNN,既沒有LSTM / MDLSTM也沒有注意機制來利用高級語義上下文。然而,值得注意的是,我們的識別模型在使用詞彙時獲得了最高的結果,語義上下文可以訪問。這再次驗證了所提出的ACE損失函數的穩健性和有效性。
在圖5中,我們使用ACE損失函數提供由識別模型處理的一些真實圖像。 首先將原始文本圖像標準化並放置在形狀(96,100)的空白圖像的中心。我們觀察到,在識別之後,2D預測呈現出與原始文本圖像中的字符高度相似的空間分佈,這暗示了所提出的ACE損失函數的有效性。
左邊兩列代表原始圖像和使用 (96, 100) 標準化後的版本。第三列展示了對文本圖像的2D預測。在最右邊的一列,我們重疊輸入和預測圖像,並觀察2D空間中類似的字符分佈。
4.2 離線手寫中文文本識別
由於其龐大的字符集(7,357個類),多樣化的寫作風格和人物風格問題,離線HCTR問題非常複雜且難以解決。因此,評估一維預測中ACE損失的穩健性和有效性是一個有利的試驗平臺。
4.2.1 實現細節
對於離線HCTR問題,我們的模型使用CASIA-HWDB 數據集進行訓練,並使用標準基準ICDAR 2013競賽數據集進行測試。
對於HCTR問題,具有長度爲70的預測序列的網絡架構規定如下:(126,576)Input - 8C3 - MP2 - 32C3 - MP2 - 128C3 - MP2 - 5*256C3 - MP2 - 512C3 - 512C3 - MP2 - 512C2 - 3*512ResLSTM - 7357FC - Output,其中 代表核數量 ,核大小 的卷積層, 表示卷積核大小 的最大池化層, 是有 個核的全連接層,ResLSTM是 residual LSTM。HCTR問題的評估標準是ICDAR 2013競賽規定的正確率(CR)和準確率(AT)。
4.2.2 實驗結果
在表3中,我們提供了ACE損失和之前方法的比較。證明了,提出的ACE損失函數比之前的方法有更好的表現,包括基於模型[34] [37]的MDLSTM,基於模型[10]的HMM和沒有語言模型的超分割方法[27, 44, 45, 48]。和場景文本識別比較,手寫中文文本識別問題具有獨特的挑戰,例如大的字符集(7357類)和字符書寫問題。因此,ACE損失函數優於先前方法的優越性能可以正確驗證其序列識別問題的魯棒性和通用性。
4.3 計算日常場景中的目標
計算自然日常圖像中目標類的實例數量通常會遇到複雜的現實生活情況,例如,計數,外觀和對象尺度的較大差異。因此,我們在日常場景中計算目標的問題上驗證了ACE損失函數,以證明其通用性。
4.3.1 實現細節
作爲多標籤目標分類和目標檢測任務的基準,PASCAL VOC數據集包含每個圖像的類別標籤,以及可以轉換爲對象編號標籤的邊界框註釋。在我們的實現中,我們通過對零處的計數進行閾值處理並將預測舍入到最接近的整數來累積類別 的預測以獲得 。給定一個類別 和一張圖像 的預測和ground truth 計數 , 則 ,$RMSE_k = \sqrt{\frac{1}{N} \sum_{i=1}^N \frac{(\hat{c}{ik} - c{ik})^2}{c_{ik} + 1}} $。
4.3.2 實驗結果
表4顯示了提出的ACE損失函數與之前用於PASCAL VOC 2007測試數據集的方法之間的比較,用於計算日常場景中的目標。提出的ACE損失函數優於先前的掃視和子化方法[6],相關損失方法[40]和Always-0方法(預測最常見的ground truth計數)。結果顯示了ACE損失函數的通用性,因爲它可以容易地應用於除序列識別之外的問題,例如計數問題,需要最小的領域知識。
在圖6中,我們在ACE損失下提供由計數模型處理的真實圖像。如圖所示,我們使用ACE損失訓練的計數模型設法“關注”關鍵目標發生的位置。與文本識別問題不同,其中使用ACE損失函數訓練的識別模型傾向於對字符進行預測,使用ACE損失函數訓練的計數模型在目標上提供更均勻的預測分佈。而且,它爲對象的不同部分分配不同級別的“注意力”。例如,當觀察圖片中的紅色時,我們注意到計數模型更注重人的面部。這種現象與我們的直覺相對應,因爲面部是個體中最獨特的部分
前四列圖片顯示我們的模型正確識別的示例。右上角圖像已正確識別,但標註是錯誤的。
4.4 複雜度分析
在表5中,我們將ACE的參數,運行時內存和運行時間與CTC和注意力的參數,運行時內存和運行時間進行比較。在12GB內存的單個NVIDIA TITAN X圖形卡上使用minibatch 64和模型預測長度T = 144執行結果。與CTC類似,提出的ACE不需要任何參數來完成其功能。由於其簡單性,ACE需要運行時內存,比CTC和注意力少五倍。此外,它的速度至少是CTC和注意力的30倍。值得注意的是,憑藉所有這些優勢,提出的ACE實現的性能與CTC和注意力相當或更高(括號中的標籤提供了錯誤的預測)。
5. 結論
本文提出了一種新穎,直觀的ACE損失函數,用於解決序列識別問題,對CTC和注意力具有競爭性能。由於其簡單性,ACE損失函數很容易通過簡單地用ACE替換CTC,快速實現只有四個基本公式,快速推斷並在大約 並行傳播,並顯示邊際內存要求。還研究了ACE損失函數的兩個有效特性:(1)它可以很容易地處理邊緣適應的2D預測問題(2)它不需要用於監督的字符順序信息,這允許它超越序列識別問題,例如計數問題。