基於attention機制實現 CRNN OCR文字識別

定義網絡結構

實現 BahdanauAttention,其中 score 的實現方法爲 perceptron 形式

class BahdanauAttention(tf.keras.Model):
    """Additive (perceptron-style) attention.

    The score is ``V(tanh(W1(features) + W2(hidden)))``. It is normalised with
    a softmax along axis 1 and used to take a weighted sum of ``features``
    along that same axis.
    """

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        dense = tf.keras.layers.Dense
        self.W1 = dense(units)  # projects the encoder features
        self.W2 = dense(units)  # projects the previous decoder state
        self.V = dense(1)       # collapses each position to one score

    def call(self, features, hidden):
        """Return ``(context_vector, attention_weights)``.

        features: source encoding produced by the encoder.
        hidden: decoder hidden state from step i-1.
        """
        # Insert an axis at position 1 so the state broadcasts against every
        # position along axis 1 of `features`
        # (assumes features is (batch, steps, dim) -- TODO confirm).
        query = tf.expand_dims(hidden, 1)

        # The last axis of the score has size 1 because self.V is Dense(1).
        projected = self.W1(features) + self.W2(query)
        score = self.V(tf.nn.tanh(projected))

        # Normalise across axis 1 so the weights sum to one per sample.
        attention_weights = tf.nn.softmax(score, axis=1)

        # Weighted sum over axis 1 removes that axis from the result.
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)

        return context_vector, attention_weights

定義GRU單元

def gru(units):
    """Build a GRU layer that returns both the output sequence and final state.

    The CuDNN-backed layer is used when ``tf.test.is_gpu_available()`` reports
    a GPU; otherwise the generic Keras GRU is used.
    """
    shared = dict(return_sequences=True,
                  return_state=True,
                  recurrent_initializer='glorot_uniform')

    if not tf.test.is_gpu_available():
        # NOTE(review): sigmoid recurrent activation presumably mirrors the
        # CuDNN kernel so both variants behave alike -- confirm.
        return tf.keras.layers.GRU(units, recurrent_activation='sigmoid', **shared)

    return tf.keras.layers.CuDNNGRU(units, **shared)

使用CRNN feature 提取層 和 單層GRU生成編碼器Encoder

class Encoder(tf.keras.Model):
    """CRNN-style encoder: a convolutional feature extractor followed by a GRU.

    enc_units: number of hidden units in the encoder GRU
    batch_sz: batch size
    """

    def __init__(self, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units

        layers = tf.keras.layers

        def conv3x3(filters):
            # 3x3 "same" convolution with ReLU, the building block of the stack.
            return layers.Conv2D(filters, [3, 3], padding="same", activation='relu')

        self.cnn = tf.keras.Sequential([
            conv3x3(64),
            layers.MaxPool2D(pool_size=[2, 2], strides=2),
            conv3x3(128),
            layers.MaxPool2D(pool_size=[2, 2], strides=2),
            conv3x3(256),
            conv3x3(256),
            # From here on only the first spatial axis is pooled.
            layers.MaxPool2D(pool_size=[2, 1], strides=[2, 1]),
            conv3x3(512),
            layers.BatchNormalization(),
            conv3x3(512),
            layers.BatchNormalization(),
            layers.MaxPool2D(pool_size=[2, 1], strides=[2, 1]),
            layers.Conv2D(512, [2, 2], strides=[2, 1], padding="same", activation='relu'),
            # Turn the feature map into a 25-step sequence of 512-d vectors.
            # Assumes the CNN output holds exactly 25*512 values per sample
            # (e.g. a 32x100 input) -- TODO confirm for other input sizes.
            layers.Reshape((25, 512)),
        ])

        self.gru = gru(self.enc_units)

    def call(self, x):
        """Return the GRU output sequence and its final state for images `x`."""
        sequence, final_state = self.gru(self.cnn(x))
        return sequence, final_state

    def initialize_hidden_state(self):
        """Return an all-zero state of shape (batch_sz, enc_units)."""
        return tf.zeros((self.batch_sz, self.enc_units))

定義 attention 機制和 GRU 單元的解碼器 Decoder

class Decoder(tf.keras.Model):
    """Single-step attention decoder: embedding + attention context -> GRU -> logits."""

    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = gru(self.dec_units)
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        """Run one decoding step.

        x: ids of the previous tokens.
        hidden: previous decoder state; here it only drives the attention.
        enc_output: encoder output sequence that is attended over.

        Returns ``(logits, state, attention_weights)``.
        """
        context_vector, attention_weights = self.attention(enc_output, hidden)

        # Look up the embedding of the previous token(s).
        embedded = self.embedding(x)

        # Prepend the context (with an axis inserted at position 1) to the
        # embedding along the last axis.
        gru_input = tf.concat([tf.expand_dims(context_vector, 1), embedded], axis=-1)

        # The GRU is called without an explicit initial state.
        output, state = self.gru(gru_input)

        # Merge the leading axes so each row is one GRU output vector.
        flat_output = tf.reshape(output, (-1, output.shape[2]))

        # Project every row onto the vocabulary.
        logits = self.fc(flat_output)

        return logits, state, attention_weights

準備數據

數據集採用mjsynth.tar.gz,這個數據集有些問題,某些樣本大小寫未分開標註,某些樣本顏色梯度不夠,可以先訓練一個模型後對數據集做篩選,然後再 fine-tune。

定義字典

# Maps every vocabulary symbol to an integer id (and back).
class LanguageIndex():
    """Bidirectional symbol <-> id lookup built from ``cfg.CHAR_VECTOR``."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.vocab = cfg.CHAR_VECTOR

        self.create_index()

    def create_index(self):
        """Fill ``word2idx`` and ``idx2word``.

        Ids 0-3 are reserved for '<pad>', '<start>', '<end>' and the empty
        string; the symbols of ``self.vocab`` follow from id 4 in order.
        """
        reserved = ('<pad>', '<start>', '<end>', '')
        for position, token in enumerate(reserved):
            self.word2idx[token] = position

        for position, symbol in enumerate(self.vocab, len(reserved)):
            self.word2idx[symbol] = position

        # Build the reverse lookup from the finished forward mapping.
        self.idx2word.update({position: token for token, position in self.word2idx.items()})

處理 label 爲 <start> .. <end> 格式

# Dataset root directory; relative image paths read from the annotation file
# are joined onto it.
root = "../mnt/ramdisk/max/90kDICT32px"

def create_dataset_from_file(root, file_path, max_samples=1000000):
    """Collect existing image paths and their labels from an annotation file.

    Args:
        root: directory that the relative paths in the annotation file are
            joined onto.
        file_path: text file with one relative image path per line.
        max_samples: maximum number of images to keep (``None`` for no
            limit). Defaults to the previously hard-coded 1000000.

    Returns:
        ``(img_paths, labels)``. Each label is the second-to-last
        underscore-separated field of the image file name.
    """
    with open(file_path, "r") as f:
        readlines = f.readlines()

    img_paths = []
    for img_name in tqdm(readlines, desc="read dir:"):
        # Stop scanning once enough samples were found; the remaining
        # existence checks would be thrown away anyway.
        if max_samples is not None and len(img_paths) >= max_samples:
            break
        img_name = img_name.strip()
        if not img_name:
            # A blank line would resolve to `root` itself, pass the existence
            # check and then break the label extraction below.
            continue
        img_path = root + "/" + img_name
        if osp.exists(img_path):
            img_paths.append(img_path)

    labels = [img_path.split("/")[-1].split("_")[-2] for img_path in tqdm(img_paths, desc="generator label:")]
    return img_paths, labels

def preprocess_label(label):
    """Wrap a label as a space-separated ``<start> c1 c2 ... <end>`` string.

    Surrounding whitespace is stripped first. Every character is followed by
    one space and ``' <end>'`` is appended, so two spaces precede ``<end>``;
    splitting the result on ' ' therefore yields one empty-string token.
    """
    characters = label.strip()
    spaced = ''.join(ch + ' ' for ch in characters)
    return '<start> ' + spaced + ' <end>'

def load_dataset(root):
    """Load the training annotations under `root` and encode their labels.

    Reads ``root + "/annotation_train.txt"``, wraps every label with the
    start/end markers, converts it to ids and post-pads all id sequences to
    a common length.

    Returns:
        ``(img_paths_tensor, labels_tensor, labels, label_lang, label_max_len)``:
        image paths, padded label ids, raw label strings, the vocabulary
        index and the padded sequence length.
    """
    img_paths_tensor, labels = create_dataset_from_file(root, root + "/annotation_train.txt")

    processed_labels = [preprocess_label(label) for label in tqdm(labels, desc="process label:")]

    # LanguageIndex takes no constructor arguments (its vocabulary comes from
    # cfg.CHAR_VECTOR); passing the labels raised a TypeError.
    label_lang = LanguageIndex()

    labels_tensor = [[label_lang.word2idx[s] for s in label.split(' ')] for label in processed_labels]

    # max_length is defined outside this excerpt; presumably it returns the
    # length of the longest id sequence -- verify.
    label_max_len = max_length(labels_tensor)

    labels_tensor = tf.keras.preprocessing.sequence.pad_sequences(labels_tensor, maxlen=label_max_len, padding='post')

    return img_paths_tensor, labels_tensor, labels, label_lang, label_max_len

構建數據 dataset

def process_img(img_path):
    """Load an image as a float32 grayscale array of shape (32, 100, 1).

    Args:
        img_path: path of the image file.

    Returns:
        The image resized/padded by ``resize_image`` to width 100 and
        height 32, with a trailing channel axis, as ``np.float32``.

    Raises:
        IOError: if OpenCV cannot read or decode the file.
    """
    imread = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if imread is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface later as an obscure unpacking error.
        raise IOError("cannot read image: {}".format(img_path))
    imread = resize_image(imread, 100, 32)
    imread = np.expand_dims(imread, axis=-1)
    return np.array(imread, np.float32)


def resize_image(image, out_width, out_height):
    """
        Resize an image to the "good" input size

        The height is scaled to `out_height` keeping the aspect ratio. If the
        scaled width reaches `out_width`, the original image is resized
        straight to (out_width, out_height). Otherwise the scaled image is
        placed at the left of a white (255) uint8 canvas of that size.
        The padding branch builds a 2-D canvas, so it expects a
        single-channel image.
    """
    h, w = np.shape(image)[:2]
    ratio = out_height / h

    # Clamp to one pixel: a very narrow image would otherwise give a target
    # width of 0, which cv2.resize rejects.
    scaled_width = max(1, int(w * ratio))

    if scaled_width >= out_width:
        # Too wide after scaling: squeeze the original to the target size.
        return cv2.resize(image, (out_width, out_height))

    scaled = cv2.resize(image, (scaled_width, out_height))
    final_arr = np.ones((out_height, out_width), dtype=np.uint8) * 255
    final_arr[:, 0:scaled_width] = scaled
    return final_arr

# Load the training set: image paths, padded label ids, raw label strings,
# the vocabulary index and the padded label length.
img_paths_tensor, labels_tensor, labels, label_lang, label_max_len = load_dataset(root)

# Hyper-parameters taken from the project configuration.
BATCH_SIZE = cfg.TRAIN_BATCH_SIZE
# Number of whole batches the loaded image list provides.
N_BATCH = len(img_paths_tensor) // BATCH_SIZE
embedding_dim = cfg.EMBEDDING_DIM
units = cfg.UNITS

# Vocabulary size, counting the reserved tokens held in word2idx.
vocab_size = len(label_lang.word2idx)

def map_func(img_path_tensor, label_tensor, label):
    """Load the image for one dataset element (used through ``tf.py_func``).

    Args:
        img_path_tensor: image path as UTF-8 encoded bytes.
        label_tensor: padded label ids, passed through unchanged.
        label: raw label string, passed through unchanged.

    Returns:
        ``(image, label_tensor, label)`` where ``image`` is the float32 array
        produced by ``process_img``.
    """
    # Reuse the shared loader instead of repeating its read/resize/cast steps.
    imread = process_img(img_path_tensor.decode('utf-8'))
    return imread, label_tensor, label


def _load_example(path, label_ids, label_text):
    # Run the Python-side image loader for one element; the outputs are
    # declared as float32 image, int32 label ids and string label.
    return tf.py_func(map_func, [path, label_ids, label_text], [tf.float32, tf.int32, tf.string])


# Input pipeline: slice the in-memory lists, load images with 8 parallel
# calls, shuffle within a 10000-element buffer, then form full batches only.
dataset = tf.data.Dataset.from_tensor_slices((img_paths_tensor, labels_tensor, labels))
dataset = dataset.map(_load_example, num_parallel_calls=8)
dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

定義Encoder、Decoder和Optimizer ,loss函數

# Encoder and decoder share the same number of recurrent units.
encoder = Encoder(units, BATCH_SIZE)
decoder = Decoder(vocab_size, embedding_dim, units, BATCH_SIZE)

# Adam with a fixed learning rate of 1e-4.
optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)


def loss_function(real, pred):
    """Masked sparse softmax cross-entropy for one decoding step.

    Args:
        real: target ids; id 0 is the '<pad>' token and is ignored.
        pred: unnormalised logits for the same positions.

    Returns:
        The mean of the per-position losses with padded positions zeroed.
        Padded positions still count in the mean's denominator.
    """
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    # Build the mask with TensorFlow ops rather than NumPy so the loss stays
    # on-device and does not depend on eager tensor-to-array conversion.
    mask = tf.cast(tf.not_equal(real, 0), loss_.dtype)
    return tf.reduce_mean(loss_ * mask)

開啓訓練

# Checkpoints are written as ./checkpoints/ckpt-* and track the optimizer
# together with both models.
checkpoint_dir = './checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, encoder=encoder, decoder=decoder)

# Number of passes over the dataset.
EPOCHS = 100

# Training loop: teacher-forced decoding, one optimizer step per batch,
# per-batch accuracy logging and a checkpoint every second epoch.
for epoch in range(EPOCHS):
    start = time.time()

    total_loss = 0

    for (batch, (inp, targ, ground_truths)) in enumerate(dataset):
        loss = 0

        # Greedy prediction for every decoding step; used only for the
        # accuracy report below, not for the loss.
        results = np.zeros((BATCH_SIZE, targ.shape[1] - 1), np.int32)

        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(inp)

            # The decoder starts from the encoder's final state.
            dec_hidden = enc_hidden

            dec_input = tf.expand_dims([label_lang.word2idx['<start>']] * BATCH_SIZE, 1)

            # Teacher forcing - feeding the target as the next input
            for t in range(1, targ.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

                predicted_id = tf.argmax(predictions, axis=-1).numpy()

                results[:, t - 1] = predicted_id

                loss += loss_function(targ[:, t], predictions)

                # using teacher forcing
                dec_input = tf.expand_dims(targ[:, t], 1)

        # Reported loss is averaged over the target length; the gradient is
        # taken on the summed loss.
        batch_loss = (loss / int(targ.shape[1]))

        total_loss += batch_loss

        variables = encoder.variables + decoder.variables

        gradients = tape.gradient(loss, variables)

        optimizer.apply_gradients(zip(gradients, variables))

        # process_result / compute_accuracy are defined outside this excerpt.
        preds = [process_result(result, label_lang) for result in results]

        ground_truths = [l.numpy().decode() for l in ground_truths]

        acc = compute_accuracy(ground_truths, preds)

        print('Epoch {} Batch {} Loss {:.4f} Mean Loss {:.4f} acc {:f}'.format(epoch + 1, batch,
                                                                               batch_loss.numpy(),
                                                                               total_loss / (batch + 1),
                                                                               acc))
        if batch % 10 == 0:
            # Show up to five samples; a batch smaller than five must not
            # raise an IndexError.
            for i in range(min(5, len(preds))):
                print("real:{:s}  pred:{:s} acc:{:f}".format(ground_truths[i], preds[i],
                                                             compute_accuracy([ground_truths[i]], [preds[i]])))

    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

測試代碼

import os

from config import cfg
from lang_dict.lang import LanguageIndex
from net.net import *
from utils.img_utils import *

# Restrict CUDA to the device with index 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Vocabulary index built from the configured character set.
label_lang = LanguageIndex()
vocab_size = len(label_lang.word2idx)

# Inference runs on one image at a time.
BATCH_SIZE = 1
embedding_dim = cfg.EMBEDDING_DIM
units = cfg.UNITS

encoder = Encoder(units, BATCH_SIZE)
decoder = Decoder(vocab_size, embedding_dim, units, BATCH_SIZE)

# Only the encoder and decoder are tracked here; no optimizer is attached.
checkpoint_dir = './checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(encoder=encoder, decoder=decoder)

# NOTE(review): if no checkpoint exists in checkpoint_dir, latest_checkpoint
# presumably yields None and the models keep their fresh initial weights --
# confirm and consider failing loudly.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))


def evaluate(encoder, decoder, img_path, label_lang):
    """Greedily decode the text in one image and print the prediction.

    encoder, decoder: the models to run.
    img_path: path of the image to recognise.
    label_lang: vocabulary index used for the start token and for turning
        predicted ids back into text.
    """
    # Add a leading batch axis to the single preprocessed image.
    image_batch = np.expand_dims(process_img(img_path), axis=0)

    # The decoder starts from the encoder's final state.
    enc_output, dec_hidden = encoder(image_batch)

    dec_input = tf.expand_dims([label_lang.word2idx['<start>']] * BATCH_SIZE, 1)

    results = np.zeros((BATCH_SIZE, 25), np.int32)

    # 24 decoding steps fill columns 0..23; the last column stays 0.
    for step in range(24):
        # passing enc_output to the decoder
        predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

        predicted_id = tf.argmax(predictions, axis=-1).numpy()
        results[:, step] = predicted_id

        # Feed the model's own prediction back in as the next input.
        dec_input = tf.expand_dims(predicted_id, 1)

    # process_result is defined outside this excerpt.
    preds = [process_result(row, label_lang) for row in results]

    print("pred :" + preds[0])


# Sample image to recognise.
img_path = "./sample/1_bridleway_9530.jpg"

# Run greedy decoding on the sample and print the predicted text.
evaluate(encoder=encoder, decoder=decoder, img_path=img_path, label_lang=label_lang)

添加 attention 後,CRNN 收斂非常迅速,基本上一個 epoch 就能收斂

 

全部代碼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章