pytorch內置torch.nn.CTCLoss

一、開篇簡述

CTC 的全稱是Connectionist Temporal Classification,中文名稱是“連接時序分類”,這個方法主要是解決神經網絡label 和output 不對齊的問題(Alignment problem),其優點是不用強制對齊標籤且標籤可變長,僅需輸入序列和監督標籤序列即可進行訓練,目前,該方法主要應用於場景文本識別(scene text recognition)、語音識別(speech recognition)及手寫字識別(handwriting recognition)等工程場景。以往我們在百度上搜索pytorch + ctc loss得到的結果基本上warp-ctc的使用方法,warp-ctc是百度開源的一個可以應用在CPU和GPU上高效並行的CTC代碼庫,但是爲了在pytorch上使用warp-ctc我們不僅需要編譯其源代碼還需要進行安裝配置,使用起來着實麻煩。而在Pytorch 1.0.x版本內早就有內置ctc loss接口了,我們完全可以直接使用,只是很少有資料介紹如何使用該API。因此,本篇文章結合我個人工程實踐中的經驗介紹我在pytorch中使用其內置torch.nn.CTCLoss的方法,但不會對ctc loss原理進行展開,期望能給大家在工程實踐中使用torch.nn.CTCLoss帶來幫助!

二、CTCLoss接口使用說明

第一步,獲取CTCLoss()對象

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')

類初始化參數說明:

blank:空白標籤所在的label值,默認爲0,需要根據實際的標籤定義進行設定;

reduction:處理output losses的方式,string類型,可選’none’ 、 ‘mean’ 及 ‘sum’,’none’表示對output losses不做任何處理,’mean’ 則對output losses取平均值處理,’sum’則是對output losses求和處理,默認爲’mean’ 。

第二步,在迭代中調用CTCLoss()對象計算損失值

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

CTCLoss()對象調用形參說明:

log_probs:shape爲(T, N, C)的模型輸出張量,其中,T表示CTCLoss的輸入長度也即輸出序列長度,N表示訓練的batch size長度,C則表示包含有空白標籤的所有要預測的字符集總長度,log_probs一般需要經過torch.nn.functional.log_softmax處理後再送入到CTCLoss中;

targets:shape爲(N, S) 或(sum(target_lengths))的張量,其中第一種類型,N表示訓練的batch size長度,S則爲標籤長度,第二種類型,則爲所有標籤長度之和,但是需要注意的是targets不能包含有空白標籤;

input_lengths:shape爲(N)的張量或元組,但每一個元素的長度必須等於T即輸出序列長度,一般來說模型輸出序列固定後則該張量或元組的元素值均相同;

target_lengths:shape爲(N)的張量或元組,其每一個元素指示每個訓練輸入序列的標籤長度,但標籤長度是可以變化的;

舉個具體例子說明如何使用CTCLoss(),如下爲CTCLoss在車牌識別裏面的應用:

比如我們需要預測的字符集如下,其中’-‘表示空白標籤;

CHARS = ['京', '滬', '津', '渝', '冀', '晉', '蒙', '遼', '吉', '黑',
         '蘇', '浙', '皖', '閩', '贛', '魯', '豫', '鄂', '湘', '粵',
         '桂', '瓊', '川', '貴', '雲', '藏', '陝', '甘', '青', '寧',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
         ]

因爲空白標籤所在的位置爲len(CHARS)-1,而我們需要處理CTCLoss output losses的方式爲‘mean’,則需要按照如下方式初始化CTCLoss類:

ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction=’mean’)

我們設定輸出序列長度T爲18,訓練批大小N爲4且訓練數據集僅有4張車牌(爲了方便說明)如下,總的字符集長度C如上面CHARS所示爲68:

BVjUvm.pnguploading.4e448015.gif轉存失敗重新上傳取消《如何優雅的使用pytorch內置torch.nn.CTCLoss的方法》

那麼我們在訓練一次迭代中打印各個輸入形參得出如下結果:

1)log_probs由於數值比較多且爲神經網絡前向輸出結果,我們僅打印其shape出來,如下:

torch.Size([18, 4, 68])

2)打印targets如下,表示這四張車牌的訓練標籤,根據target_lengths劃分標籤後可分別表示這四張車牌:

tensor([18, 45, 33, 37, 40, 49, 63, 4, 54, 51, 34, 53, 37, 38, 22, 56, 37, 38,33, 39, 34, 46, 2, 41, 44, 37, 39, 35, 33, 40])

3)打印target_lengths如下,每個元素分別指定了按序取targets多少個元素來表示一個車牌即標籤:

(7, 7, 8, 8)

我們劃分targets後得到如下標籤:

18, 45, 33, 37, 40, 49, 63  -->> 車牌 “湘E269JY”
4, 54, 51, 34, 53, 37, 38   -->> 車牌 “冀PL3N67”
22, 56, 37, 38,33, 39, 34, 46  -->> 車牌 “川R67283F”
2, 41, 44, 37, 39, 35, 33, 40  -->> 車牌 “津AD68429”

target_lengths元素數量的不同則表示了標籤可變長。

4)打印input_lengths如下,由於輸出序列長度T已經設定爲18,因此其元素均是固定相同的:

(18, 18, 18, 18)

其中,只要模型配置固定了後,log_probs不需要我們組裝再傳送到CTCLoss,但是其餘三個輸入形參均需要我們根據實際數據集及C、T、N的情況進行設定!

三、需要注意的地方

3.1 官方所給的例程如下,但在實際應用中需要將log_probs的detach()去掉,否則無法反向傳播進行訓練;

>>> ctc_loss = nn.CTCLoss()
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
>>> loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()

3.2 blank空白標籤一定要依據空白符在預測總字符集中的位置來設定,否則就會出錯;

3.3 targets建議將其shape設爲(sum(target_lengths)),然後再由target_lengths進行輸入序列長度指定就好了,這是因爲如果設定爲(N, S),則因爲S的標籤長度如果是可變的,那麼我們組裝出來的二維張量的第一維度的長度僅爲min(S)將損失一部分標籤值(多維數組每行的長度必須一致),這就導致模型無法預測較長長度的標籤;

3.4 輸出序列長度T儘量在模型設計時就要考慮到模型需要預測的最長序列,如需要預測的最長序列其長度爲I,則理論上T應大於等於2I+1,這是因爲CTCLoss假設在最壞情況下每個真實標籤前後都至少有一個空白標籤進行隔開以區分重複項;

3.5 輸出的log_probs除了進行log_softmax()處理再送入CTCLoss外,還必須要調整其維度順序,確保其shape爲(T, N, C)!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章