本週工作進展
創新實訓正式開始,第一週主要是熟悉 keras 的使用,並嘗試復現文本檢測模型 ctpn。
2019.3.24 昨天晚上繼續做的時候發現我把模型理解錯了,下面的內容如果有人看見的話就當我在胡說八道好了(/ω\*)……
詳細工作內容
完成模型大題框架搭建,正在設計損失函數與數據集輸入方法。
如圖,模型由以下部分組成:
- VGG16 網絡,取到 conv5 的第三層;
- 在Conv5的feature map的每個位置上取3*3*C的窗口的特徵,輸入雙向 LSTM;
- 全連接層;
- 最後輸出三個分支,分別是 anchor 的縱座標與高度,anchor 是否爲文字的分數,以及邊緣提純的偏移量。
def ctpn_model():
date = np.random.rand(1, 600, 900, 3).astype(np.float32)
input_layer = vgg16_no_tail()(date)
# unfold 不會弄,先用卷積代替
x = keras.layers.Convolution2D(
512, (3, 3),
activation='relu',
padding='same',
name='cnn2rnn')(input_layer)
# 下面雙向 lstm 出了點問題,先將輸出變形
x = keras.layers.Reshape((x.shape[1], -1))(x)
# 如果 shape0 是 batch,先假設 shape1 是 h
x = keras.layers.Bidirectional(keras.layers.LSTM(128))(x)
x = keras.layers.Dense(512)(x)
vertical_coordinate = keras.layers.Dense(20)(x)
score = keras.layers.Dense(20)(x)
side_refinement = keras.layers.Dense(10)(x)
# vgg16 取 conv5 的第三層
def vgg16_no_tail():
vgg = keras.applications.VGG16(weights=None)
vgg_no_tail = keras.Model(
inputs=vgg.input,
outputs=vgg.get_layer("block5_conv3").output)
return vgg_no_tail
下一步計劃
- 模型中 cnn 轉向 rnn 的中間步驟需要提取 3*3 的矩陣並拼接,目前暫時以一個卷積層代替,視最終運行效果決定是否修改。
- 完成損失函數和數據集輸入方法的設計。