OCR -- 文本檢測 - 訓練DB文字檢測模型

百度飛槳(PaddlePaddle) - PP-OCRv3 文字檢測識別系統 預測部署簡介與總覽
百度飛槳(PaddlePaddle) - PP-OCRv3 文字檢測識別系統 Paddle Inference 模型推理(離線部署)
百度飛槳(PaddlePaddle) - PP-OCRv3 文字檢測識別系統 基於 Paddle Serving快速使用(服務化部署 - CentOS)
百度飛槳(PaddlePaddle) - PP-OCRv3 文字檢測識別系統 基於 Paddle Serving快速使用(服務化部署 - Docker)

PaddleOCR提供DB文本檢測算法,支持MobileNetV3、ResNet50_vd兩種骨幹網絡,可以根據需要選擇相應的配置文件,啓動訓練。

本節以icdar15數據集、MobileNetV3作爲骨幹網絡的DB檢測模型(即超輕量模型使用的配置)爲例,介紹如何完成PaddleOCR中文字檢測模型的訓練、評估與測試。

3.1 數據準備

本次實驗選取了場景文本檢測和識別(Scene Text Detection and Recognition)任務最知名和常用的數據集ICDAR2015。icdar2015數據集的示意圖如下圖所示:


圖 icdar2015數據集示意圖


該項目中已經下載了icdar2015數據集,存放在 /home/aistudio/data/data96799 中,可以運行如下指令完成數據集解壓,或者從鏈接中自行下載
image

~/train_data/icdar2015/text_localization 
  └─ icdar_c4_train_imgs/         icdar數據集的訓練數據
  └─ ch4_test_images/             icdar數據集的測試數據
  └─ train_icdar2015_label.txt    icdar數據集的訓練標註
  └─ test_icdar2015_label.txt     icdar數據集的測試標註

提供的標註文件格式爲:

" 圖像文件名                    json.dumps編碼的圖像標註信息"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]

json.dumps編碼前的圖像標註信息是包含多個字典的list,字典中的points表示文本框的四個點的座標(x, y),從左上角的點開始順時針排列。 transcription中的字段表示當前文本框的文字,在文本檢測任務中並不需要這個信息。 如果您想在其他數據集上訓練PaddleOCR,可以按照上述形式構建標註文件。

如果"transcription"字段的文字爲'*'或者'###',表示對應的標註可以被忽略掉,因此,如果沒有文字標籤,可以將transcription字段設置爲空字符串。

3.2 數據預處理

訓練時對輸入圖片的格式、大小有一定的要求,同時,還需要根據標註信息獲取閾值圖以及概率圖的真實標籤。所以,在數據輸入模型前,需要對數據進行預處理操作,使得圖片和標籤滿足網絡訓練和預測的需要。另外,爲了擴大訓練數據集、抑制過擬合,提升模型的泛化能力,還需要使用了幾種基礎的數據增廣方法。

本實驗的數據預處理共包括如下方法:

  • 圖像解碼:將圖像轉爲Numpy格式;
  • 標籤解碼:解析txt文件中的標籤信息,並按統一格式進行保存;
  • 基礎數據增廣:包括:隨機水平翻轉、隨機旋轉,隨機縮放,隨機裁剪等;
  • 獲取閾值圖標籤:使用擴張的方式獲取算法訓練需要的閾值圖標籤;
  • 獲取概率圖標籤:使用收縮的方式獲取算法訓練需要的概率圖標籤;
  • 歸一化:通過規範化手段,把神經網絡每層中任意神經元的輸入值分佈改變成均值爲0,方差爲1的標準正太分佈,使得最優解的尋優過程明顯會變得平緩,訓練過程更容易收斂;
  • 通道變換:圖像的數據格式爲[H, W, C](即高度、寬度和通道數),而神經網絡使用的訓練數據的格式爲[C, H, W],因此需要對圖像數據重新排列,例如[224, 224, 3]變爲[3, 224, 224];

圖像解碼

從訓練數據的標註中讀取圖像,演示DecodeImage類的使用方式。
源碼位置:\ppocr\data\imaug\operators.py

import os
import matplotlib.pyplot as plt
from paddleocr.ppocr.data.imaug.operators import DecodeImage
 

label_path = "../train_data/icdar2015/text_localization/train_icdar2015_label.txt"
img_dir = "../train_data/icdar2015/text_localization/"

# 1. 讀取訓練標籤的第一條數據
f = open(label_path, "r")
lines = f.readlines()

# 2. 取第一條數據
line = lines[0]

print("The first data in train_icdar2015_label.txt is as follows.\n", line)
img_name, gt_label = line.strip().split("\t")

# 3. 讀取圖像
image = open(os.path.join(img_dir, img_name), 'rb').read()
data = {'image': image, 'label': gt_label}

# 4. 聲明DecodeImage類,解碼圖像
decode_image = DecodeImage(img_mode='RGB', channel_first=False)
data = decode_image(data)

# 5. 打印解碼後圖像的shape,並可視化圖像
print("The shape of decoded image is ", data['image'].shape)

plt.figure(figsize=(10, 10))
plt.imshow(data['image'])
src_img = data['image']
plt.show()

image

標籤解碼

解析txt文件中的標籤信息,並按統一格式進行保存;
源碼位置:ppocr/data/imaug/label_ops.py

import os
from paddleocr.ppocr.data.imaug.label_ops  import DetLabelEncode

label_path = "../train_data/icdar2015/text_localization/train_icdar2015_label.txt"
img_dir = "../train_data/icdar2015/text_localization/"

# 1. 讀取訓練標籤的第一條數據
f = open(label_path, "r")
lines = f.readlines()

# 2. 取第一條數據
line = lines[0]

print("The first data in train_icdar2015_label.txt is as follows.\n", line)
img_name, gt_label = line.strip().split("\t")

# 3. 讀取圖像
image = open(os.path.join(img_dir, img_name), 'rb').read()
data = {'image': image, 'label': gt_label}

# 1. 聲明標籤解碼的類
decode_label = DetLabelEncode()
# 2. 打印解碼前的標籤
print("The label before decode are: ", data['label'])
data = decode_label(data)
print("\n")

# 4. 打印解碼後的標籤
print("The polygon after decode are: ", data['polys'])
print("The text after decode are: ", data['texts'])

基礎數據增廣

數據增廣是提高模型訓練精度,增加模型泛化性的常用方法,文本檢測常用的數據增廣包括隨機水平翻轉、隨機旋轉、隨機縮放以及隨機裁剪等等。

隨機水平翻轉、隨機旋轉、隨機縮放的代碼實現參考代碼。隨機裁剪的數據增廣代碼實現參考代碼

獲取閾值圖標籤

使用擴張的方式獲取算法訓練需要的閾值圖標籤;
源碼位置:ppocr/data/imaug/make_border_map.py

# 從PaddleOCR中import MakeBorderMap
from ppocr.data.imaug.make_border_map import MakeBorderMap

# 1. 聲明MakeBorderMap函數
generate_text_border = MakeBorderMap()

# 2. 根據解碼後的輸入數據計算bordermap信息
data = generate_text_border(data)

# 3. 閾值圖可視化
plt.figure(figsize=(10, 10))
plt.imshow(src_img)

text_border_map = data['threshold_map']
plt.figure(figsize=(10, 10))
plt.imshow(text_border_map)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章