轉自:https://www.cnblogs.com/YSPXIZHEN/p/11343426.html
GitHub:https://github.com/pengcao/chinese_ocr https://github.com/xiaofengShi/CHINESE-OCR
|-angle 基於VGG分類模型的文字方向檢測預測 |-bash 環境安裝 |----setup-python3.sh 安裝python3環境 |----setup-python3-cpu.sh 安裝CPU環境 |----setup-python3-gpu.sh 安裝CPU環境 |-crnn |-ctpn 基於CTPN模型的文本區域檢測模型訓練及預測 |-model |----modelAngle.h5 文字方向檢測VGG模型 |----my_model_keras.h5 文本識別CRNN模型 |-ocr 基於CRNN的文本識別模型預測 |-result 預測圖片 |-test 測試圖片 |-train 基於CRNN的文本識別模型訓練
環境要求:
python3.6 tensorflow1.7-cpu/gpu graphviz pydot (py)torch torchvision
- 卸載舊版本的pytorch和torchvision
pip uninstall torchvision pip uninstall torch
- 安裝pytorch
1)Anaconda搜索torch
2)選擇標記處點開
3)Anaconda Prompt - conda install -c peterjc123 pytorch
- 安裝torchvision
conda install torchvision -c pytorch # TorchVision requires PyTorch 1.1 or newer
離線安裝pytorch 1).whl安裝
從pytorch官網https://pytorch.org/previous-versions/下載合適版本torch及torchvision的whl
# 直接對whl文件進行編譯即可 pip install torch-0.4.0-cp36-cp36m-linux_x86_64.whl pip install torchvision-0.2.1-py2.py3-none-any.whl
2).tar.gz安裝
下載對應版本的.tar.gz文件,並解壓
# 進入解壓目錄,執行安裝命令 python setup.py install
離線安裝GCC(Tensorflow部分第三方模塊需要GCC進行編譯,所以在安裝第三方的依賴包之前先安裝GCC)
從https://pkgs.org/download/gcc下載gcc-4.8.5-28.el7_5.1.x86_64.rpm版本,並且在require部分下載所需要的rpm文件(根據報錯缺失的rpm下載)
rpm -ivh gcc-4.8.5-28.el7_5.1.x86_64.rpm # 如果已經有舊的版本會報conflicts with錯誤 rpm -ivh gcc-4.8.5-28.el7_5.1.x86_64.rpm --force
模型
- 文本方向檢測網絡 - Classify(vgg16)
- 文本區域檢測網絡 - CTPN(CNN+RNN) - 支持CPU、GPU環境,一鍵部署 - 文本檢測訓練Github:https://github.com/eragonruan/text-detection-ctpn
- EndToEnd文本識別網絡 - CRNN(CNN+GRU/LSTM+CTC)
文本方向檢測
訓練:基於圖像分類模型 - VGG16分類模型,訓練0、90、180、270度檢測的分類模型(
angle/predict.py
),
訓練圖片8000張,準確率88.23%
模型:https://pan.baidu.com/s/1Sqbnoeh1lCMmtp64XBaK9w(n2v4)
文本區域檢測
基於深度學習的文本區域檢測方法:http://xiaofengshi.com/2019/01/23/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-TextDetection/
CTPN(CNN+RNN)網路結構:
CTPN是一種基於目標檢測方法的文本檢測模型,在repo的CTPN中anchor的設置爲固定寬度,高度不同,相關代碼如下:
def generate_anchors(base_size=16, ratios=[0.5, 1, 2], scales=2 ** np.arange(3, 6)): heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283] widths = [16] sizes = [] for h in heights: for w in widths: sizes.append((h, w)) return generate_basic_anchors(sizes)
基於這種設置,CTPN只能檢測水平方向的文本,如果想要CTPN可以支持垂直文本檢測,可以在anchor生成函數上進行修改
對CTPN進行訓練:
- 訓練腳本 - 定位到路徑:./ctpn/ctpn/train_net.py
- 預訓練的VGG網絡路徑[VGG_imagenet.npy]:https://pan.baidu.com/s/1jzrcCr0tX6xAiVoolVRyew(a5ze) - 將預訓練權重下載下來,pretrained_model指向該路徑即可
- 模型的預訓練權重[checkpoint]:https://pan.baidu.com/s/1oS6_kqHgmcunkooTAXE8GA(xmjv)
- CTPN數據集[VOCdevkit.zip]:https://pan.baidu.com/s/1NXFmdP_OgRF42xfHXUhBHQ - 下載解壓後將.ctpn/lib/datasets/pascal_voc.py文件中的pascal_voc類中的參數self.devkit_path指向數據集的路徑即可
端到端(EndToEnd)文本識別
OCR識別採用GRU+CTC[CRNN(CNN+GRU/LSTM+CTC)]端到端識別技術,實現不分隔識別不定長文字
CTC - CTC算法原理
CTC是一種解碼機制,在使用CTPN提取到待檢測文本行之後,需要識別提取到的區域內的文本內容,目前廣泛存在兩種解碼機制。
- 一種是Seq2Seq機制,輸入的是圖像,經過卷積編碼之後再使用RNN解碼,爲了提高識別的準確率,一般會加入Attention機制;
- 另一種就是CTC解碼機制,但是對於CTC解碼要滿足一個前提,那就是輸入序列的長度不小於輸出序列的長度。CTC主要用於序列解碼,不需要對序列中的每個元素進行標記,只需要知道輸入序列對應的整個Label是什麼即可,針對OCR項目,也就是輸入一張圖像上面寫着“歡迎來到中國”這幾個字,只需要是這幾個字,而沒必要知道這幾個字在輸入圖像中所在的具體位置,實際上如果知道每個字所在的位置,就是單字符識別了,的確會降低任務的複雜多,但是現實中我們沒有這麼多標記號位置的數據,這個時候CTC就顯得很重要了。
對CRNN進行訓練:
- keras版本:./train/keras_train/train_batch.py(model_path-指向預訓練權重位置,MODEL_PATH-指向模型訓練保存的位置)
- pythorch版本:./train/pytorch-train/crnn_main.py
parser.add_argument( '--crnn', help="path to crnn (to continue training)", default=預訓練權重的路徑) parser.add_argument( '--experiment', help='Where to store samples and models', default=定義的模型訓練的權重保存位置)
模型:
keras模型預訓練權重:https://pan.baidu.com/s/14cTCedz1ESnj0mM9ISm__w(1kb9)
pytorch預訓練權重:https://pan.baidu.com/s/1kAXKudJLqJbEKfGcJUMVtw(9six)
預測測試
運行predict.predict(demo).py:寫入測試圖片的路徑即可
如果想要顯示CTPN的結果,修改文件./ctpn/ctpn/other.py的draw_boxes函數的最後部分,cv2.inwrite('dest_path',img),如此可以得到CTPN檢測的文字區域框以及圖像的OCR識別結果