【AI實戰】手把手教你實現文字識別模型(入門篇:驗證碼識別)

文字識別在現實生活中有着非常重要的應用,主要由文字檢測、內容識別兩個關鍵步驟組成,在本博客之前的文章中已介紹了文字檢測、內容識別的經典模型原理(見文章:大話文本檢測經典模型:CTPN , 大話文本識別經典模型:CRNN),本文主要從實戰的角度介紹如何實現文字識別模型。

在之前的文章中,已經介紹過了跟文字識別相關的實戰內容:基於MNIST數據集識別手寫數字的實戰內容(見文章:訓練你的第一個AI模型:MNIST手寫數字識別模型),這個相對簡單。今天再介紹文字識別的另一個經典應用:驗證碼識別,作爲文字識別的實戰入門篇。

 

驗證碼在手機APP、WEB網站中非常普遍,主要是爲了防止惡意登錄、刷票、灌水、爬蟲等異常行爲,也可能是爲了緩解系統的後臺壓力(例如在秒殺、搶票時,強制要求輸入驗證碼)。本文主要介紹文本型驗證碼的識別,文本型驗證碼由數字、英文大小寫字母,甚至中文隨機組成,再進行變形扭曲、加干擾線、加背景噪音等操作,主要是爲了防止被光學字符識別(OCR)之類的程序自動識別出圖片上的文字而失去效果,如下圖:

由於存在着比較強的干擾信息,因此,直接使用OCR進行識別,效果很不理想,而通過AI可很好地實現這種複雜信息的識別。目前百度等AI開放平臺,也提供了驗證碼識別的開放接口,但由於驗證碼可由各APP、網站根據任意自定的規則隨機組合生成,因此,這些AI平臺的驗證碼識別開放接口在某些場景下效果很好,在某些場景下可能就失靈了。針對具體的場景,我們通過自己訓練驗證碼識別的AI模型,能很好地解決該場景下的驗證碼識別問題。

 

下面開始介紹使用Tensorflow構建驗證碼的識別模型,主要步驟如下:

  • step 1. 獲取驗證碼圖片
  • step 2. 圖片標註
  • step 3. 訓練模型
  • step 4. 模型應用

 

1、獲取驗證碼圖片

(1)如果是自己練習的,可直接隨機生成驗證碼圖片作爲基礎數據集。在python裏面使用captcha庫來快速生成驗證碼圖片,通過pip install captcha進行安裝,或者手動下載captcha-0.3-py3-none-any.whl文件進行安裝。(注:anaconda無法通過conda install 直接安裝captcha,但可使用anaconda裏面的pip來安裝captcha),核心代碼如下:

from captcha.image import ImageCaptcha
import random

# 生成驗證碼的字符集
CHAR_SET = ['0','1','2','3','4','5','6','7','8','9']
CHAR_SET_LEN = len(CHAR_SET)

# 驗證碼長度
CAPTCHA_LEN  = 4

for i in range(CHAR_SET_LEN):
    for j in range(CHAR_SET_LEN):
        for k in range(CHAR_SET_LEN):
            for l in range(CHAR_SET_LEN):
                captcha_text = CHAR_SET[i] + CHAR_SET[j] + CHAR_SET[k] + CHAR_SET[l]
                image = ImageCaptcha()
                image.write(captcha_text, '/tmp/mydata/' + captcha_text + '.jpg')

生成的效果如下圖

(2)如果是要針對某個網站的驗證碼進行識別的,則可使用一些工具將對應的驗證碼下載下來。一般網站登錄的界面如下:

其中,通常可直接點擊驗證碼圖片,或旁邊的“換一張”按鈕,更換驗證碼圖片。這時,可使用像“按鍵精靈”之類的模擬鼠標操作的軟件,錄製一段腳本,然後在驗證碼圖片處模擬右鍵鼠標保存圖片,再點擊驗證碼圖片更換新的驗證碼,如此反覆,即可下載該網站的大量驗證碼圖片,用於訓練模型。至於這個下載驗證碼圖片的腳本嘛,爲了不教壞大家,此處省略500字,嘿嘿~

 

2、圖片標註

如果第1步是自己隨機生成驗證碼圖片的,那麼在保存圖片時,文件名便是該驗證碼圖片的文本內容,無須再進行標註。

如果第1步是下載了某個網站的驗證碼圖片的,那麼需要先人工對驗證碼圖片的文本內容進行標註,以方便接下來的模型訓練。可通過觀察,將驗證碼圖片的文本信息記在文件名中(重命名),通過這種方式進行圖片標註,也可以單獨記錄在文本文件中。

 

3、訓練模型

(1)標籤one-hot編碼

爲了能夠將驗證碼圖片的文本信息輸入到卷積神經網絡模型裏面去訓練,需要將文本信息向量化編碼。在這裏使用“熱獨編碼”(one-hot),即使用01編碼表示文本信息。本項目的驗證碼文本長度爲4位,驗證碼編碼由0至9的數字組成,例如驗證碼文本信息爲“1086”,則one-hot編碼時在相應的位置標爲1,其餘爲0,如下圖

則“1086”經one-hot編碼後變爲[0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0] 。將驗證碼文本信息進行one-hot編碼的核心代碼如下:

def text2label(text):
    label = np.zeros(CAPTCHA_LEN * CHAR_SET_LEN)
    for i in range(len(text)):
        idx = i * CHAR_SET_LEN + CHAR_SET.index(text[i])
        label[idx] = 1
    return label

(2)讀取圖片文件

讀取驗證碼圖片、驗證碼文本內容(保存在文件名中),並編寫獲取下個批量數據的方法,主要函數如下:


# 獲取驗證碼圖片路徑及文本內容
def get_image_file_name(img_path):
    img_files = []
    img_labels = []
    for root, dirs, files in os.walk(img_path):
        for file in files:
            if os.path.splitext(file)[1] == '.jpg':
                img_files.append(root+'/'+file)
                img_labels.append(text2label(os.path.splitext(file)[0]))
    return img_files,img_labels

# 批量獲取數據
def get_next_batch(img_files,img_labels,batch_size):
    batch_x = np.zeros([batch_size, IMAGE_WIDTH*IMAGE_HEIGHT])
    batch_y = np.zeros([batch_size, CAPTCHA_LEN * CHAR_SET_LEN])

    for i in range(batch_size):
        idx = random.randint(0, len(img_files) - 1)
        file_path = img_files[idx]
        image = cv2.imread(file_path)
        image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
        image = image.astype(np.float32)
        image = np.multiply(image, 1.0 / 255.0)
        batch_x[i, :] = image
        batch_y[i, :] = img_labels[idx]

    return batch_x,batch_y

(3)構建CNN模型

由於驗證碼的識別相對比較簡單,借鑑LeNet的網絡結構構建CNN模型,由3個卷積層和1個全連接層組成,網絡結構圖如下:

核心代碼如下:

# 圖像尺寸
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160

# 網絡相關變量
X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
Y = tf.placeholder(tf.float32, [None, CAPTCHA_LEN * CHAR_SET_LEN])
keep_prob = tf.placeholder(tf.float32)  # dropout

# 驗證碼 CNN 網絡
def crack_captcha_cnn_network (w_alpha=0.01, b_alpha=0.1):
    x = tf.reshape(X, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 32]))
    b_c1 = tf.Variable(b_alpha * tf.random_normal([32]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv1 = tf.nn.dropout(conv1, keep_prob)

    w_c2 = tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64]))
    b_c2 = tf.Variable(b_alpha * tf.random_normal([64]))
    conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv2 = tf.nn.dropout(conv2, keep_prob)

    w_c3 = tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 64]))
    b_c3 = tf.Variable(b_alpha * tf.random_normal([64]))
    conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv3 = tf.nn.dropout(conv3, keep_prob)

    w_d = tf.Variable(w_alpha * tf.random_normal([8 * 20 * 64, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))
    dense = tf.nn.dropout(dense, keep_prob)

    w_out = tf.Variable(w_alpha * tf.random_normal([1024, CAPTCHA_LEN * CHAR_SET_LEN]))
    b_out = tf.Variable(b_alpha * tf.random_normal([CAPTCHA_LEN * CHAR_SET_LEN]))
    out = tf.add(tf.matmul(dense, w_out), b_out)
    return out

(4)訓練模型

通過設置好模型訓練的迭代輪次、批量獲取樣本數量、學習率等參數,讀取驗證碼圖片集,並隨機劃分出訓練集、測試集,再加載本項目的網絡模型進行訓練,每100步評估一次準確率和保存模型文件。核心代碼如下:

# 模型的相關參數
step_cnt = 200000  # 迭代輪數
batch_size = 16  # 批量獲取樣本數量
learning_rate = 0.0001  # 學習率

# 讀取驗證碼圖片集
img_path = '/tmp/mydata/'
img_files, img_labels = get_image_file_name(img_path)

# 劃分出訓練集、測試集
x_train,x_test,y_train,y_test=train_test_split(img_files,img_labels,test_size=0.2,random_state=33)

# 加載網絡結構
output = crack_captcha_cnn_network()

# 損失函數、優化器
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

# 評估準確率
predict = tf.reshape(output, [-1, CAPTCHA_LEN, CHAR_SET_LEN])
max_idx_p = tf.argmax(predict, 2)
max_idx_l = tf.argmax(tf.reshape(Y, [-1, CAPTCHA_LEN, CHAR_SET_LEN]), 2)
correct_pred = tf.equal(max_idx_p, max_idx_l)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

for step in range(step_cnt):
    # 訓練模型
        batch_x, batch_y = get_next_batch(x_train, y_train,batch_size)
        _, loss_ = sess.run([optimizer, loss], feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
        print('step:',step, 'loss:',loss_)

        # 每100步評估一次準確率
        if step % 100 == 0:
            batch_x_test, batch_y_test = get_next_batch(x_test, y_test,batch_size)
            acc = sess.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
            print('step:',step,'acc:',acc)

            # 保存模型
            saver.save(sess, '/tmp/mymodel/crack_captcha.ctpk', global_step=step)

        step += 1

訓練的過程如下圖所示:

經過一段時間的訓練後,評估的準確率可達到99%以上,能非常準確地識別出驗證碼。

 

4、模型應用

通過加載訓練好後的模型文件,即可輸入圖片進行驗證碼識別,核心代碼如下:

# 加載網絡結構
output = crack_captcha_cnn_network()

saver = tf.train.Saver()
with tf.Session() as sess:
    model_path = '/tmp/mymodel/'
    saver.restore(sess, tf.train.latest_checkpoint(model_path))

    output_rate=tf.reshape(output, [-1, CAPTCHA_LEN, CHAR_SET_LEN])
    predict = tf.argmax(output_rate, 2)
    text_list,rate_list = sess.run([predict,output_rate], feed_dict={X: [captcha_image], keep_prob: 1})   # captcha_image 爲待識別的驗證碼圖片

    tmptext = text_list[0].tolist()
    text=''
    for i in range(len(tmptext)):
        text = text + CHAR_SET[tmptext[i]]

    print('識別結果:',text)

以上就是文字識別的入門實戰內容:驗證碼圖片文本識別。通過本次的學習,可瞭解簡單的文本識別的實現方式。

 

關注本人公衆號“大數據與人工智能Lab”(BigdataAILab),然後回覆“代碼”關鍵字可獲取完整的源代碼

 

推薦相關閱讀

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章