文字檢測算法CTPN網絡模型及tensorflow版本代碼介紹

åºæ¯æå­æ£æµâCTPNåçä¸å®ç°

網絡結構:

1、基礎網絡時VGG16,在conv5_3卷積之後的特徵圖上進行後續處理

2、在conv5_3的特徵圖之上,使用3x3的卷積核進行滑窗處理,這就是Faster R-CNN中使用的RPN網絡

3、然後以特徵圖的行爲單位,將每行內容分別輸入到雙向LSTM循環網絡中,將雙向循環網絡的輸出結果進行concat連接,每個LSTM輸出的結果是128維向量,所以每個位置的輸出結果是256維的向量,得到的特徵圖大小就是 H x W x 256

4、在特徵圖上的每個(h, w)的位置上,後面連接一個全連接層FC,其實也可以使用1x1的卷積層(Faster R-CNN中使用的就是1x1的卷積層),分別得到3個分支,第一個分支表示基於RPN網絡在每個(h, w)位置上生成的k個anchor而預測得到的k個推薦框proposal是前景還是背景的得分,可以使用softmax實現,第二個分支表示基於RPN網絡在每個(h, w)位置上生成的k個anchor而預測得到的k個推薦框proposal的高度以及proposal高度的中心位置,因爲proposal的寬度是固定的16。

關於RPN網絡的作用在我之前寫的Faster R-CNN中RPN網絡總結中有描述,簡單總結一下就是,在RPN網絡處理階段,anchor的大小,座標,以及anchor所屬的label(anchor區域中是否包含目標)都是可以提前確定的,然後基於FC全連接層輸出的預測結果計算損失函數,以anchor爲基準點,使得訓練後的模型預測結果更接近GT的真實結果。anchor的具體作用是爲了儘量覆蓋更多位置、更多尺度的目標,然後以anchor的位置和label爲基準,訓練模型參數,使得模型預測結果更接近真實值GT。

這裏需要說明一下,有些人可能會疑惑,論文裏明明白白的寫了使用了全連接層FC,大家都知道全連接層要求輸入的大小是固定的,而論文裏面又說了CTPN實質上一個全卷積網絡,可以處理任意大小的圖片。問題的根本就在於雖然使用了全連接網絡,但是全連接網絡是作用於特徵圖的通道維度上的,特徵圖的通道維度是確定的,與輸入圖片的大小無關,所以CTPN可以處理任意大小的圖片,其本質上也就是全卷積網絡了。

tensorflow版本代碼介紹,本篇代碼來自於開源項目 text-detection-ctpn,遺憾的是項目中的百度網盤地址已經失效,沒有獲取到訓練數據,訓練過程還沒跑通。

推理預測階段:

1)、demo.py代碼文件關鍵函數

resize_image對輸入的圖片進行縮放,縮放後的圖片短邊爲600,並保持圖片的長寬比不變。

2)、model_train.py代碼文件關鍵函數

從model函數中可以明確的看到,模型在VGG16的基礎上先進行了3x3的卷積操作,也就是RPN網絡中的滑窗操作,然後將滑窗結果輸入到雙向LSTM。下面看一下雙向LSTM中的操作。

在雙向LSTM函數中可以看到,net的形狀是[N * H, W, C],其中N * H就是輸入到LSTM的batch_size大小,W是輸入序列的長度,C是每一步輸入到LSTM中的數據維度,可以看到,每一步都是沿着W的方向,將通道C上的向量輸入到LSTM中,然後將兩個LSTM的輸出結果進行連接。按照上面的描述,經過LSTM之後的輸出結果要輸入到FC全連接層,用於預測proposal的座標和得分,也就是lstm_fc函數的功能。

可以看到在lstm_fc函數中,分別在每個(h, w)位置的通道向量上進行FC操作,這也就是前面說的,CTPN中雖然使用了FC全連接層,但本質上仍是一個全卷積網絡,因爲這裏的FC只和通道的大小有關係,與特徵圖的H和W無關。

在model函數的最後將lstm_fc函數的輸出結果使用softmax函數轉換爲對應的前景和背景的概率。

3)、proposal_layer.py代碼文件關鍵函數

proposal_layer函數用來計算基於RPN的預測結果以及生成的anchor座標,計算RPN預測結果在最終圖像上的座標,然後對生成的proposal進行裁剪、過濾、NMS得到最終的proposal。

模型訓練階段:

前面粗略介紹了模型的正向推理過程,下面講一下模型的訓練過程。

1)、anchor_target_layer.py代碼文件關鍵函數

anchor_target_layer函數的主要功能是在特徵圖上的每個位置生成10個anchor,並基於訓練數據上的GT座標位置計算與anchor IoU得到anchor的label。然後對生成的anchor進行篩選、過濾,最終保留128個positive anchor和128個negative anchor,如果positive anchor的數量太少,就是用negative anchor進行填充。

loss函數用來計算模型的損失函數:

從上面代碼可以看到,對於分類分支使用交叉熵損失,對於座標預測分支使用smooth_l1損失,tensorflow版本的代碼中沒有計算第三個分支。其中,對於分類分支的損失計算的是RPN預測的概率和anchor真是標籤之間的損失,迴歸分支的損失計算的是RPN預測結果與anchor座標的差異和GT與anchor座標差異之間的損失。分類和迴歸損失的計算與Faster R-CNN中一致。

以上簡單介紹了CTPN文字檢測算法的模型,以及tensorflow版本的代碼實現。要想深入理解算法的具體內容,還是要實地研究一下代碼的具體實現。以上僅是個人的理解與觀點,如有錯誤,歡迎指正。

 

道阻且長,加油吧!少年。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章