【OCR】文字檢測:傳統算法、CTPN、EAST

在我的測試中，EAST 和 CTPN 的速度差不多，但 EAST 的準確率高約 4%。

http://xiaofengshi.com/2019/01/23/深度學習-TextDetection/

https://codeload.github.com/GlassyWing/text-detection-ocr/zip/master

1、傳統算法

import cv2
import numpy as np

# Demo input: the sample image is loaded once, at import time.
# cv2.imread does not raise on failure; it returns None if the file cannot be read.
imagePath = "asset/0015.jpg"
img = cv2.imread(imagePath, cv2.IMREAD_COLOR)
def get_box(img, min_area=1000, max_height_ratio=1.3):
    """Find candidate text regions with classic edge detection + morphology.

    Pipeline: grayscale -> x-direction Sobel gradient -> Otsu binarisation ->
    dilate / erode / dilate with wide rectangular kernels, so neighbouring
    character strokes merge into blobs -> external contours -> minimum-area
    rectangle per blob.

    Args:
        img: BGR image as returned by ``cv2.imread``. A single-channel
            (already grayscale) array is used as-is.
        min_area: contours whose area is below this many pixels are dropped.
        max_height_ratio: boxes whose height exceeds ``width * max_height_ratio``
            are dropped, keeping the wide/flat shapes typical of text lines.

    Returns:
        list of integer corner arrays from ``cv2.boxPoints`` (4 points each),
        one per kept region.
    """
    # Convert to grayscale (skip when the input already has a single channel).
    if img.ndim == 2:
        gray = img
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Sobel gradient along x, then binarise with Otsu's automatic threshold.
    sobel = cv2.Sobel(gray, cv2.CV_8U, 1, 0, ksize=3)
    _, binary = cv2.threshold(sobel, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)

    # Structuring elements for dilation / erosion.
    element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (30, 9))
    element2 = cv2.getStructuringElement(cv2.MORPH_RECT, (24, 6))

    # Dilate once so the outlines stand out.
    dilation = cv2.dilate(binary, element2, iterations=1)
    # Erode once to remove fine detail.
    erosion = cv2.erode(dilation, element1, iterations=1)
    # Dilate again to make the text-line outlines more pronounced.
    dilation2 = cv2.dilate(erosion, element2, iterations=2)

    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); index -2 selects the contours in both versions.
    contours = cv2.findContours(dilation2, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    region = []
    for cnt in contours:
        # Skip small blobs (noise).
        if cv2.contourArea(cnt) < min_area:
            continue

        # Minimum-area (possibly rotated) bounding rectangle.
        rect = cv2.minAreaRect(cnt)
        print("rect is: ")
        print(rect)

        # Corner coordinates. np.int0 was removed in NumPy 2.0; np.intp is
        # the integer type it aliased, and astype truncates the same way.
        box = cv2.boxPoints(rect).astype(np.intp)

        # Corners 0 and 2 are opposite, so these are the axis-aligned extents.
        height = abs(box[0][1] - box[2][1])
        width = abs(box[0][0] - box[2][0])

        # Text lines are wide and flat: drop boxes that are too tall.
        if height > width * max_height_ratio:
            continue

        region.append(box)
        print('box is:', box)
    return region


from math import *

def calcuate_angle(lines_h, length, max_lines=10):
    """Estimate a skew angle, in degrees, from roughly horizontal line segments.

    Args:
        lines_h: sequence of entries, each an iterable of ``(x1, y1, x2, y2)``
            segments (e.g. the ``[[x1, y1, x2, y2]]`` entries built by the
            ``__main__`` demo of this file).
        length: threshold compared against the *squared* segment length;
            segments below it are ignored.
        max_lines: only the first ``max_lines`` entries of ``lines_h`` are
            considered (``None`` uses all of them).

    Returns:
        The negated median per-segment angle. Each segment angle is measured
        from the vertical via ``atan(dx / dy)`` and folded into [-45, 45].
        For an even count the upper median is used. Returns 0 when no segment
        passes the length filter.
    """
    angle_all = []
    for segments in lines_h[:max_lines]:
        for x1, y1, x2, y2 in segments:
            dx = x2 - x1
            dy = y2 - y1
            if dx * dx + dy * dy < length:
                continue
            if dy == 0:
                # Exactly horizontal; folded to 0 below.
                angle_line = 90
            else:
                angle_line = atan(dx / dy) * 180 / pi
            # Fold near-horizontal angles (close to +/-90) into [-45, 45].
            if angle_line > 45:
                angle_line -= 90
            if angle_line < -45:
                angle_line += 90
            angle_all.append(angle_line)

    # No usable segment: report "no skew" instead of raising IndexError.
    if not angle_all:
        return 0

    angle_all.sort()
    return -angle_all[len(angle_all) // 2]

def RotateDegree(img, degree):
    """Rotate ``img`` by ``degree`` degrees (positive turns left) without cropping.

    The output canvas is enlarged to the bounding box of the rotated image,
    and pixels not covered by the source are filled with white.
    """
    rad = radians(degree)
    abs_sin = fabs(sin(rad))
    abs_cos = fabs(cos(rad))

    src_h, src_w = img.shape[:2]
    dst_h = int(src_w * abs_sin + src_h * abs_cos)
    dst_w = int(src_h * abs_sin + src_w * abs_cos)

    rotation = cv2.getRotationMatrix2D((src_w / 2, src_h / 2), degree, 1)
    # Translate so the rotated content is centred in the enlarged canvas.
    rotation[0, 2] += (dst_w - src_w) / 2
    rotation[1, 2] += (dst_h - src_h) / 2
    return cv2.warpAffine(img, rotation, (dst_w, dst_h), borderValue=(255, 255, 255))


if __name__ == '__main__':
    # Demo: detect text boxes, estimate the skew from one edge of each box,
    # and write a deskewed copy of the image.
    image = cv2.imread('./asset/0015.jpg')
    if image is None:
        raise SystemExit("cannot read ./asset/0015.jpg")

    text_recs = get_box(image)
    print(len(text_recs))

    lines_h = []
    for rec in text_recs[:10]:
        # Corners 0 and 1 of cv2.boxPoints form one edge of the rectangle;
        # use that edge as a line segment whose slope reflects the skew.
        (x1, y1), (x2, y2) = rec[0], rec[1]
        print('---:', (x1, y1), (x2, y2))
        lines_h.append([[x1, y1, x2, y2]])

    # Without any detected box there is nothing to measure: do not rotate.
    angle = calcuate_angle(lines_h, 100) if lines_h else 0
    print(angle)

    new_img = RotateDegree(image, angle)
    cv2.imwrite('./asset/0015_rotate.jpg', new_img)

    # To visualise the detected boxes instead:
    # for box in text_recs:
    #     cv2.drawContours(image, [box], 0, (0, 255, 0), 2)
    # cv2.imwrite('./asset/output_1.jpg', image)

2、CTPN

為了整合進 Web 版本，模型只能初始化一次，且不能與其他模型衝突，所以我做了以下修改：

class CTPN:
    """CTPN text detector: VGG16 backbone, bidirectional GRU, anchor heads.

    Three Keras models are built over the same layers:

    * ``model`` -- training graph with outputs ``[cls, regr]``; weights from
      ``weight_path`` are loaded into it.
    * ``parallel_model`` -- ``model`` wrapped with ``multi_gpu_model`` when
      ``num_gpu > 1`` (otherwise the same object); it is the compiled model
      used by :meth:`train`.
    * ``predict_model`` -- the same graph plus softmax class probabilities,
      used by :meth:`predict`.

    Construction loads the VGG16 weights and compiles the model, so build one
    instance and reuse it.
    """

    def __init__(self, lr=0.00001, image_channels=3, vgg_trainable=True, weight_path=None, num_gpu=1):
        """
        Args:
            lr: learning rate of the Adam optimiser.
            image_channels: number of input channels; the spatial size is
                left open (``(None, None, image_channels)``).
            vgg_trainable: whether the VGG16 backbone is fine-tuned.
            weight_path: optional weights file loaded into ``self.model``.
            num_gpu: number of GPUs; more than one wraps the training model
                with ``multi_gpu_model``.
        """
        self.image_channels = image_channels
        self.image_shape = (None, None, image_channels)
        self.vgg_trainable = vgg_trainable
        self.num_gpu = num_gpu
        self.lr = lr
        self.model, self.parallel_model, self.predict_model = self.__build_model()
        if weight_path is not None:
            self.model.load_weights(weight_path)

    def __build_model(self):
        """Build and compile the CTPN graph.

        Returns:
            ``(train_model, parallel_model, predict_model)``.
        """
        base_model = VGG16(weights=None, include_top=False, input_shape=self.image_shape)
        base_model.load_weights(vgg_weights_path)
        base_model.trainable = bool(self.vgg_trainable)

        inputs = base_model.input
        # block5_conv3 sits after four 2x poolings in VGG16, i.e. 1/16 of the
        # input resolution (matching the 16-pixel anchors used in predict).
        sub_output = base_model.get_layer('block5_conv3').output

        x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu',
                   name='rpn_conv1')(sub_output)

        # NOTE(review): _reshape presumably turns feature-map rows into
        # sequences for the recurrent layer -- confirm against its definition.
        x1 = Lambda(_reshape, output_shape=(None, 512))(x)

        # Named 'blstm' but implemented with GRU cells; 2 x 128 = 256 features.
        x2 = Bidirectional(GRU(128, return_sequences=True), name='blstm')(x1)

        x3 = Lambda(_reshape2, output_shape=(None, None, 256))([x2, x])
        x3 = Conv2D(512, (1, 1), padding='same', activation='relu', name='lstm_fc')(x3)

        # 10 * 2 channels: presumably 10 anchors per cell with 2 values each
        # (reshaped to (None, 2) below) -- confirm against utils.gen_anchor.
        cls = Conv2D(10 * 2, (1, 1), padding='same', activation='linear', name='rpn_class_origin')(x3)
        regr = Conv2D(10 * 2, (1, 1), padding='same', activation='linear', name='rpn_regress_origin')(x3)

        cls = Lambda(_reshape3, output_shape=(None, 2), name='rpn_class')(cls)
        cls_prod = Activation('softmax', name='rpn_cls_softmax')(cls)

        regr = Lambda(_reshape3, output_shape=(None, 2), name='rpn_regress')(regr)

        predict_model = Model(inputs, [cls, regr, cls_prod])

        train_model = Model(inputs, [cls, regr])

        parallel_model = train_model
        if self.num_gpu > 1:
            parallel_model = multi_gpu_model(train_model, gpus=self.num_gpu)

        adam = Adam(self.lr)
        parallel_model.compile(optimizer=adam,
                               loss={'rpn_regress': _rpn_loss_regr, 'rpn_class': _rpn_loss_cls},
                               loss_weights={'rpn_regress': 1.0, 'rpn_class': 1.0})

        return train_model, parallel_model, predict_model

    def train(self, train_data_generator, epochs, **kwargs):
        """Fit the compiled model; extra ``kwargs`` go to ``fit_generator``."""
        self.parallel_model.fit_generator(train_data_generator, epochs=epochs, **kwargs)

    def predict(self, image, output_path=None, mode=1):
        """Detect text lines in an image.

        Args:
            image: a file path (read via ``np.fromfile`` + ``cv2.imdecode``) or
                an image array in BGR channel order. Single-channel arrays are
                converted to 3 channels.
            output_path: in mode 1, where to save the annotated image.
            mode: 1 = draw the detected lines, show them with matplotlib and
                optionally save them; 2 = return the results instead.

        Returns:
            In mode 2, ``(text, img)``: ``text`` is the int32 result of
            ``TextProposalConnectorOriented.get_text_lines`` and ``img`` is the
            (possibly padded) image fed to the network. ``None`` otherwise.

        Raises:
            ValueError: if ``image`` is a path whose content cannot be decoded.
        """
        if isinstance(image, str):
            img = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("cannot decode image file: {}".format(image))
        else:
            img = image
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        h, w = img.shape[:2]

        # image size length must be greater than or equals 16 x 16,
        # because of the image will be reduced by 16 times.
        # Smaller images are padded with white up to 16 pixels per side.
        if h < 16 or w < 16:
            transform_w = max(16, w)
            transform_h = max(16, h)
            transform_img = np.full((transform_h, transform_w, img.shape[2]), 255, dtype=np.uint8)
            transform_img[:h, :w, :] = img
            h = transform_h
            w = transform_w
            img = transform_img

        # zero-center by mean pixel; subtract in float so a uint8 image can
        # never wrap around.
        m_img = img.astype(np.float32) - utils.IMAGE_MEAN
        m_img = np.expand_dims(m_img, axis=0)

        cls, regr, cls_prod = self.predict_model.predict_on_batch(m_img)
        anchor = utils.gen_anchor((int(h / 16), int(w / 16)), 16)

        bbox = utils.bbox_transfor_inv(anchor, regr)
        bbox = utils.clip_box(bbox, [h, w])

        # keep anchors whose text score exceeds utils.IOU_SELECT
        fg = np.where(cls_prod[0, :, 1] > utils.IOU_SELECT)[0]
        select_anchor = bbox[fg, :]
        select_score = cls_prod[0, fg, 1]
        select_anchor = select_anchor.astype('int32')

        # filter size
        keep_index = utils.filter_bbox(select_anchor, 16)

        # non-maximum suppression
        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = utils.nms(nmsbox, 1 - utils.IOU_SELECT)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        # connect the remaining proposals into text lines
        textConn = TextProposalConnectorOriented()
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])

        text = text.astype('int32')

        if mode == 1:
            # Draw on a copy so an array passed in by the caller is not modified.
            canvas = img.copy()
            for line in text:
                draw_rect(line, canvas)

            # OpenCV images are BGR; matplotlib expects RGB.
            plt.imshow(canvas[:, :, ::-1])
            plt.show()
            if output_path is not None:
                cv2.imwrite(output_path, canvas)
        elif mode == 2:
            return text, img

    def config(self):
        """Return the constructor settings that are persisted (not the weights)."""
        return {
            "image_channels": self.image_channels,
            "vgg_trainable": self.vgg_trainable,
            "lr": self.lr
        }

    @staticmethod
    def save_config(obj, config_path):
        """Write ``obj.config()`` to ``config_path`` as JSON."""
        with open(config_path, "w+") as outfile:
            json.dump(obj.config(), outfile)

    @staticmethod
    def load_config(config_path):
        """Read a JSON config written by ``save_config`` and return it as a dict."""
        with open(config_path, "r") as infile:
            return dict(json.load(infile))

不過目前檢測框的位置似乎還不正確。

 

3、EAST

https://github.com/huoyijie/AdvancedEAST

初始化時遇到的 Bug：https://blog.csdn.net/Maisie_Nan/article/details/103121134

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章