更新流程↓
Task01:賽題理解
Task02:數據讀取與數據擴增
Task03:字符識別模型
Task04:模型訓練與驗證
Task05:模型集成
好孩子看不見
比賽鏈接
文章目錄
0. BaseLine思路
將被識別圖片剪裁轉化爲多個規範的單字符進行識別後拼接。本次賽事使用SVHN數據集對卷積神經網絡(CNN)模型進行訓練。
具體包含以下步驟:
- 創建虛擬環境
- 賽題數據讀取(封裝爲Pytorch的Dataset和DataLoder)
- 構建CNN模型(使用Pytorch搭建)
- 模型訓練與驗證
- 模型結果預測
0.1. 創建虛擬環境
0.1.1. Anaconda3安裝
具體操作可以上B站觀看教程,在此不再贅述。
0.1.2. 使用Anaconda3創建虛擬環境
0. 以下所有操作基於Windows 10系統下完成,本次構建環境所用到的是python3.7 + torch1.3.1gpu 版本。
1. 在安裝好Anaconda3之後,我們在開始菜單欄找到並運行Anaconda Navigator。
2. 點擊Anaconda Navigator左側的 環境(Environments) 窗口,顯示出的列表中存在一個base(root)環境。我們點擊下方的 +號(Create) ,在彈出的窗口中,Packages選項勾選Python並選擇版本,Name欄輸入所要被創建的環境名字,在本次賽事中環境命名爲:py37_torch131 。點擊Create既可開始創建,等待些許時間就可在環境“base”下面看見新創建的“py37_torch131”。
3. 鼠標選中“py37_torch131”,點擊名字右側的開始按鈕,選擇open terminal,既可激活環境並彈出CMD命令提示符。
4. 輸入conda install pytorch=1.3.1 torchvision cudatoolkit=10.0
來安裝pytorch1.3.1。(注:若因爲下載速度緩慢而失敗,可以選擇使用清華鏡像源)
5. 輸入pip install jupyter tqdm opencv-python matplotlib pandas
一鍵安裝所需其它依賴庫。
6. 輸入jupyter notebook
來啓動JupyterNotebook進行代碼編譯。
至此,虛擬環境已經創建完畢,可以進行代碼編寫。
1. 賽題理解
1.1. 賽題數據
進入賽事界面的賽事與數據,從文件中下載數據並解壓。其中,訓練集數據train包括3W張照片,驗證集數據val包括1W張照片,測試集test包括4W張照片。
1.2. 數據標籤
訓練集和驗證集的標籤使用 .JSON格式。對於數據集中每張圖片將給出對應的編碼標籤,和具體的字符框的位置,可用於模型訓練:
top | height | left | width | label |
---|---|---|---|---|
左上角座標X | 字符高度 | 左上角最表Y | 字符寬度 | 字符編碼 |
因爲數據集中圖片是含有多個字符的,所以提供的數據會包含多個字符的邊框信息
例如:
1.3. 字符識別方法
數據集中圖片包含的字符個數爲2-6個,因此我們需要對不定長的字符進行識別,目前較爲常見的有以下三種解決本問題的思路。
1.3.1. 簡單入門思路:定長字符識別
可以將賽題抽象爲一個定長字符識別問題,在數據集中最多的字符個數爲6個。因此可以對於所有的圖像都抽象爲6個字符的識別問題,字符23填充爲23XXXX,字符231填充爲231XXX。
處理之後原始的賽題轉化爲6個字符的分類問題。每張圖片會進行6次11種判別的分類(0到9以及爲null的X),若判別爲X則表明該字符及之後字符都爲空。
1.3.2. 專業字符識別思路:不定長字符識別
在字符識別研究中,有特定的方法來解決此種不定長的字符識別問題,比較典型的有CRNN字符識別模型。在本次賽題中給定的圖像數據都比較規整,可以視爲一個單詞或者一個句子。
1.3.3. 專業分類思路:檢測再識別
在賽題數據中已經給出了訓練集、驗證集中所有圖片中字符的位置,因此可以首先將字符的位置進行識別,利用物體檢測的思路完成。
此種思路需要參賽選手構建字符檢測模型,對測試集中的字符進行識別。選手可以參考物體檢測模型SSD或者YOLO來完成。
1.4. 讀取數據
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
# 數據標註處理
def parse_json(d):
arr = np.array([d['top'], d['height'], d['left'], d['width'], d['label']])
arr = arr.astype(int)
return arr
#訓練集標籤載入
train_json = json.load(open('../input/train.json'))
img = cv2.imread('../input/train/000000.png')
arr = parse_json(train_json['000000.png'])
#圖片分割
for idx in range(arr.shape[1]):
plt.subplot(1, arr.shape[1]+1, idx+2)
plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx],arr[2, idx]:arr[2, idx]+arr[3, idx]])
plt.title(arr[4, idx])
plt.xticks([]); plt.yticks([])
1.5. 評測指標
以標籤整體識別準確率爲評價指標。
一張圖片的識別結果與其標籤完全相同即爲正確,其中任何一個字符不同都視爲是錯誤。
最終正確率具體計算公式如下: