【創新實訓 第一週】 CTPN 嘗試復現 2019.3.23

本週工作進展

創新實訓正式開始,第一週主要是熟悉 keras 的使用,並嘗試復現文本檢測模型 ctpn。

CTPN 論文地址

2019.3.24 昨天晚上繼續做的時候發現我把模型理解錯了,下面的內容如果有人看見的話就當我在胡說八道好了(/ω\*)……


詳細工作內容

完成模型大題框架搭建,正在設計損失函數與數據集輸入方法。

 

 如圖,模型由以下部分組成:

  1. VGG16 網絡,取到 conv5 的第三層;
  2. 在Conv5的feature map的每個位置上取3*3*C的窗口的特徵,輸入雙向 LSTM;
  3. 全連接層;
  4. 最後輸出三個分支,分別是 anchor 的縱座標與高度,anchor 是否爲文字的分數,以及邊緣提純的偏移量。
def ctpn_model():
    date = np.random.rand(1, 600, 900, 3).astype(np.float32)

    input_layer = vgg16_no_tail()(date)

    # unfold 不會弄,先用卷積代替
    x = keras.layers.Convolution2D(
        512, (3, 3),
        activation='relu',
        padding='same',
        name='cnn2rnn')(input_layer)

    # 下面雙向 lstm 出了點問題,先將輸出變形
    x = keras.layers.Reshape((x.shape[1], -1))(x)

    # 如果 shape0 是 batch,先假設 shape1 是 h
    x = keras.layers.Bidirectional(keras.layers.LSTM(128))(x)

    x = keras.layers.Dense(512)(x)

    vertical_coordinate = keras.layers.Dense(20)(x)
    score = keras.layers.Dense(20)(x)
    side_refinement = keras.layers.Dense(10)(x)


# vgg16 取 conv5 的第三層
def vgg16_no_tail():
    vgg = keras.applications.VGG16(weights=None)
    vgg_no_tail = keras.Model(
        inputs=vgg.input,
        outputs=vgg.get_layer("block5_conv3").output)
    return vgg_no_tail

下一步計劃

  • 模型中 cnn 轉向 rnn 的中間步驟需要提取 3*3 的矩陣並拼接,目前暫時以一個卷積層代替,視最終運行效果決定是否修改。
  • 完成損失函數和數據集輸入方法的設計。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章