CTPN訓練集準備

  1. 下載 vgg_16.ckpt 預訓練模型
  2. 準備一堆你需要訓練的圖片,使用labelme進行標註,得到一堆json文件
    json文件大致如下:
{
  "flags": {},
  "shapes": [
    {
      "label": "str",
      "line_color": null,
      "fill_color": null,
      "points": [
        [
          91,
          183
        ],
        [
          178,
          183
        ],
        [
          178,
          231
        ],
        [
          91,
          231
        ]
      ]
    }
    ],
  "lineColor": [
    0,
    255,
    0,
    128
  ],
  "fillColor": [
    255,
    0,
    0,
    128
  ],
  "imagePath": "..\\img\\000_004.jpg",
  "imageData": "/9..."
  }
  3. 使用如下代碼將上一步生成的json文件轉換成split_label.py所需要的格式
def json2txt(json_dir=r'', output_dir=r''):
    """Convert labelme JSON annotations into the gt_*.txt files split_label.py needs.

    For every ``*.json`` file in *json_dir*, the corner points of all shapes
    are flattened into rows of 8 comma-separated numbers (one 4-point text
    box per row) and written to ``output_dir/gt_<name>.txt``.

    Parameters
    ----------
    json_dir : str
        Folder containing the labelme ``*.json`` annotation files.
    output_dir : str
        Folder that receives the generated ``gt_*.txt`` files.

    NOTE(review): assumes every shape has exactly 4 points (labelme
    rectangles/quads) — the reshape to (-1, 8) fails otherwise; confirm
    the annotation convention.
    """
    for json_name in os.listdir(json_dir):
        # Skip anything that is not a labelme JSON file (the original
        # crashed on stray files in the directory).
        if not json_name.lower().endswith('.json'):
            continue
        json_file = os.path.join(json_dir, json_name)
        with open(json_file, 'r') as rf:
            info = json.load(rf)
        # Flatten all corner points of all shapes in this file.
        point_list = [point
                      for item in info['shapes']
                      for point in item['points']]
        point_arry = np.array(point_list).reshape((-1, 8))
        # splitext (not split('.')[0]) so dotted base names match the
        # 'gt_' + splitext(image)[0] lookup done by split_label.py.
        base = os.path.splitext(json_name)[0]
        output_path = os.path.join(output_dir, 'gt_' + base + '.txt')
        np.savetxt(output_path, point_arry, fmt='%s', delimiter=',')

轉換後的格式:
每一行爲一個矩形框的4個點

91,183,178,183,178,231,91,231
191,183,401,183,401,232,191,232
503,185,605,185,605,234,503,234
616,192,747,192,747,232,616,232
769,196,832,196,832,232,769,232
847,194,925,194,925,242,847,242
936,194,1071,194,1071,240,936,240
92,234,274,234,274,272,92,272
96,287,412,287,412,327,96,327
94,338,454,338,454,391,94,391
96,400,416,400,416,443,96,443
94,452,345,452,345,498,94,498
92,511,325,511,325,563,92,563
87,569,341,569,341,616,87,616
92,627,372,627,372,676,92,676
94,691,431,691,431,731,94,731
87,751,312,751,312,776,87,776
87,776,798,776,798,834,87,834
  4. 修改split_label.py中的目錄路徑,DATA_FOLDER路徑下有包含你的圖片文件夾"image"和上一步生成的標籤文件夾"label",OUTPUT爲你的輸出目錄。
    PS: split_label.py和utils.py主要來自於 https://github.com/eragonruan/text-detection-ctpn
# split_label.py
import os
import sys

import cv2 as cv
import numpy as np
from tqdm import tqdm

sys.path.append(os.getcwd())
from utils import orderConvex, shrink_poly

# ---------------------------------------------------------------------------
# Paths: DATA_FOLDER must contain an "image" folder and a "label" folder
# (the gt_*.txt files produced by json2txt); results are written under OUTPUT.
# ---------------------------------------------------------------------------
DATA_FOLDER = r"E:\code\OCR\data"
OUTPUT = r"E:\code\OCR\data\output"
MAX_LEN = 1200  # cap for the longer image side after rescaling
MIN_LEN = 600   # target for the shorter image side after rescaling

im_fns = os.listdir(os.path.join(DATA_FOLDER, "image"))
im_fns.sort()

# Create the output directory layout up front.
if not os.path.exists(os.path.join(OUTPUT, "image")):
    os.makedirs(os.path.join(OUTPUT, "image"))
if not os.path.exists(os.path.join(OUTPUT, "label")):
    os.makedirs(os.path.join(OUTPUT, "label"))

for im_fn in tqdm(im_fns):
    try:
        _, fn = os.path.split(im_fn)
        bfn, ext = os.path.splitext(fn)
        if ext.lower() not in ['.jpg', '.png']:
            continue

        gt_path = os.path.join(DATA_FOLDER, "label", 'gt_' + bfn + '.txt')
        img_path = os.path.join(DATA_FOLDER, "image", im_fn)

        img = cv.imread(img_path)
        if img is None:
            # cv.imread signals failure by returning None, not by raising.
            raise IOError("cannot read image: {}".format(img_path))
        img_size = img.shape
        im_size_min = np.min(img_size[0:2])
        im_size_max = np.max(img_size[0:2])

        # Scale so the short side becomes MIN_LEN, but never let the long
        # side exceed MAX_LEN (use the constants, not literal 600/1200).
        im_scale = float(MIN_LEN) / float(im_size_min)
        if np.round(im_scale * im_size_max) > MAX_LEN:
            im_scale = float(MAX_LEN) / float(im_size_max)
        new_h = int(img_size[0] * im_scale)
        new_w = int(img_size[1] * im_scale)

        # Round each side UP to a multiple of 16 (CTPN's anchor stride).
        # BUG FIX: the original condition was `new_h // 16 == 0`, which is
        # only true for sizes below 16, so sides already divisible by 16
        # were still padded by an extra 16 pixels; upstream uses `% 16`.
        new_h = new_h if new_h % 16 == 0 else (new_h // 16 + 1) * 16
        new_w = new_w if new_w % 16 == 0 else (new_w // 16 + 1) * 16

        re_im = cv.resize(img, (new_w, new_h), interpolation=cv.INTER_LINEAR)
        re_size = re_im.shape

        # Read the ground-truth quadrilaterals, rescale them onto the
        # resized image, and normalise vertex order via the convex hull.
        polys = []
        with open(gt_path, 'r') as f:
            lines = f.readlines()
        for line in lines:
            splitted_line = line.strip().lower().split(',')
            x1, y1, x2, y2, x3, y3, x4, y4 = map(float, splitted_line[:8])
            poly = np.array([x1, y1, x2, y2, x3, y3, x4, y4]).reshape([4, 2])
            poly[:, 0] = poly[:, 0] / img_size[1] * re_size[1]
            poly[:, 1] = poly[:, 1] / img_size[0] * re_size[0]
            polys.append(orderConvex(poly))

        # Slice every polygon into 16-pixel-wide strips and keep the
        # axis-aligned bounding box of each strip.
        res_polys = []
        for poly in polys:
            # Drop degenerate boxes: top edge (p0-p1) or left edge (p3-p0)
            # shorter than 10 pixels.
            if np.linalg.norm(poly[0] - poly[1]) < 10 or np.linalg.norm(poly[3] - poly[0]) < 10:
                continue

            for r in shrink_poly(poly).reshape([-1, 4, 2]):
                x_min = np.min(r[:, 0])
                y_min = np.min(r[:, 1])
                x_max = np.max(r[:, 0])
                y_max = np.max(r[:, 1])
                res_polys.append([x_min, y_min, x_max, y_max])

        # Persist the resized image and one "x_min,y_min,x_max,y_max"
        # line per anchor strip.
        cv.imwrite(os.path.join(OUTPUT, "image", fn), re_im)
        with open(os.path.join(OUTPUT, "label", bfn + ".txt"), "w") as f:
            for p in res_polys:
                f.write(",".join(str(p[i]) for i in range(4)) + '\n')
    except Exception as exc:
        # Keep processing the remaining images, but report why this
        # one failed (the original bare `except:` hid the cause).
        print("Error processing {}: {}".format(im_fn, exc))

裏面用到了utils.py裏的功能函數,這裏需要安裝一個shapely的包,直接去https://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely 下對應的版本安裝即可

# utils.py
import numpy as np
from shapely.geometry import Polygon


def pickTopLeft(poly):
    """Rotate a 4-vertex polygon so it starts at the top-left corner.

    Of the two vertices with the smallest x, the one with the smaller y
    is chosen as the starting corner; the cyclic vertex order is kept.
    """
    by_x = np.argsort(poly[:, 0])
    # Among the two leftmost vertices, start from the higher (smaller y) one.
    start = by_x[0] if poly[by_x[0], 1] < poly[by_x[1], 1] else by_x[1]
    order = [(start + k) % 4 for k in range(4)]
    return poly[order, :]


def orderConvex(p):
    """Normalise a quadrilateral: convex hull, 4 points, top-left first.

    NOTE(review): takes the first 4 hull vertices — assumes the hull is a
    proper quadrilateral; degenerate (collinear) input would yield fewer
    points. Confirm upstream guarantees non-degenerate boxes.
    """
    hull = Polygon(p).convex_hull
    # shapely lists exterior coords counter-clockwise; reverse the rows
    # to get clockwise order before picking the starting corner.
    pts = np.array(hull.exterior.coords)[:4][::-1]
    return pickTopLeft(pts).reshape([4, 2])


def shrink_poly(poly, r=16):
    """Split a 4-point text polygon into width-``r`` vertical strips.

    The top edge (p0->p1) and bottom edge (p3->p2) are modelled as lines
    ``y = kx + b``; the polygon is cut into strips aligned to multiples
    of ``r`` — the fixed-width anchor convention CTPN training expects.

    Parameters
    ----------
    poly : (4, 2) array
        Vertices ordered top-left, top-right, bottom-right, bottom-left
        (as produced by orderConvex).
    r : int
        Strip width in pixels (default 16). BUG FIX: the original accepted
        ``r`` but hard-coded 16/15 throughout; it is now honored, with
        identical behavior at the default.

    Returns
    -------
    (N, 8) int32 array — one row of 4 (x, y) corner pairs per strip.
    np.int was removed in NumPy 1.24; np.int32 is used instead.

    NOTE(review): a vertical top/bottom edge gives zero x-span, so the
    slope becomes inf/nan (numpy float division does not raise) — confirm
    inputs are non-degenerate, width > 0 quadrilaterals.
    """
    x_min = int(np.min(poly[:, 0]))
    x_max = int(np.max(poly[:, 0]))

    # Top edge through p0-p1, bottom edge through p3-p2: y = kx + b.
    k1 = (poly[1][1] - poly[0][1]) / (poly[1][0] - poly[0][0])
    b1 = poly[0][1] - k1 * poly[0][0]

    k2 = (poly[2][1] - poly[3][1]) / (poly[2][0] - poly[3][0])
    b2 = poly[3][1] - k2 * poly[3][0]

    res = []

    # First grid line strictly right of x_min; last grid line <= x_max.
    start = int((x_min // r + 1) * r)
    end = int((x_max // r) * r)

    # Leading (possibly partial) strip from x_min to the first grid line.
    p = x_min
    res.append([p, int(k1 * p + b1),
                start - 1, int(k1 * (p + r - 1) + b1),
                start - 1, int(k2 * (p + r - 1) + b2),
                p, int(k2 * p + b2)])

    # Full strips on the r-pixel grid.
    for p in range(start, end + 1, r):
        res.append([p, int(k1 * p + b1),
                    (p + r - 1), int(k1 * (p + r - 1) + b1),
                    (p + r - 1), int(k2 * (p + r - 1) + b2),
                    p, int(k2 * p + b2)])
    return np.array(res, dtype=np.int32).reshape([-1, 8])

這裏會把圖片resize到608*864,然後生成對應的文本框。都在你指定的OUTPUT文件夾下。
如果把生成的文本框畫到相應的圖片上就長這樣:

在這裏插入圖片描述
labelme中標註的是這樣:
在這裏插入圖片描述
然後就可以使用OUTPUT下的image和label兩個文件夾去訓練咯~

發佈了63 篇原創文章 · 獲贊 55 · 訪問量 10萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章