更新流程↓
Task01:赛题理解
Task02:数据读取与数据扩增
Task03:字符识别模型
Task04:模型训练与验证
Task05:模型集成
好孩子看不见
比赛链接
文章目录
0. BaseLine思路
将被识别图片剪裁转化为多个规范的单字符进行识别后拼接。本次赛事使用SVHN数据集对卷积神经网络(CNN)模型进行训练。
具体包含以下步骤:
- 创建虚拟环境
- 赛题数据读取(封装为Pytorch的Dataset和DataLoder)
- 构建CNN模型(使用Pytorch搭建)
- 模型训练与验证
- 模型结果预测
0.1. 创建虚拟环境
0.1.1. Anaconda3安装
具体操作可以上B站观看教程,在此不再赘述。
0.1.2. 使用Anaconda3创建虚拟环境
0. 以下所有操作基于Windows 10系统下完成,本次构建环境所用到的是python3.7 + torch1.3.1gpu 版本。
1. 在安装好Anaconda3之后,我们在开始菜单栏找到并运行Anaconda Navigator。
2. 点击Anaconda Navigator左侧的 环境(Environments) 窗口,显示出的列表中存在一个base(root)环境。我们点击下方的 +号(Create) ,在弹出的窗口中,Packages选项勾选Python并选择版本,Name栏输入所要被创建的环境名字,在本次赛事中环境命名为:py37_torch131 。点击Create既可开始创建,等待些许时间就可在环境“base”下面看见新创建的“py37_torch131”。
3. 鼠标选中“py37_torch131”,点击名字右侧的开始按钮,选择open terminal,既可激活环境并弹出CMD命令提示符。
4. 输入conda install pytorch=1.3.1 torchvision cudatoolkit=10.0
来安装pytorch1.3.1。(注:若因为下载速度缓慢而失败,可以选择使用清华镜像源)
5. 输入pip install jupyter tqdm opencv-python matplotlib pandas
一键安装所需其它依赖库。
6. 输入jupyter notebook
来启动JupyterNotebook进行代码编译。
至此,虚拟环境已经创建完毕,可以进行代码编写。
1. 赛题理解
1.1. 赛题数据
进入赛事界面的赛事与数据,从文件中下载数据并解压。其中,训练集数据train包括3W张照片,验证集数据val包括1W张照片,测试集test包括4W张照片。
1.2. 数据标签
训练集和验证集的标签使用 .JSON格式。对于数据集中每张图片将给出对应的编码标签,和具体的字符框的位置,可用于模型训练:
top | height | left | width | label |
---|---|---|---|---|
左上角座标X | 字符高度 | 左上角最表Y | 字符宽度 | 字符编码 |
因为数据集中图片是含有多个字符的,所以提供的数据会包含多个字符的边框信息
例如:
1.3. 字符识别方法
数据集中图片包含的字符个数为2-6个,因此我们需要对不定长的字符进行识别,目前较为常见的有以下三种解决本问题的思路。
1.3.1. 简单入门思路:定长字符识别
可以将赛题抽象为一个定长字符识别问题,在数据集中最多的字符个数为6个。因此可以对于所有的图像都抽象为6个字符的识别问题,字符23填充为23XXXX,字符231填充为231XXX。
处理之后原始的赛题转化为6个字符的分类问题。每张图片会进行6次11种判别的分类(0到9以及为null的X),若判别为X则表明该字符及之后字符都为空。
1.3.2. 专业字符识别思路:不定长字符识别
在字符识别研究中,有特定的方法来解决此种不定长的字符识别问题,比较典型的有CRNN字符识别模型。在本次赛题中给定的图像数据都比较规整,可以视为一个单词或者一个句子。
1.3.3. 专业分类思路:检测再识别
在赛题数据中已经给出了训练集、验证集中所有图片中字符的位置,因此可以首先将字符的位置进行识别,利用物体检测的思路完成。
此种思路需要参赛选手构建字符检测模型,对测试集中的字符进行识别。选手可以参考物体检测模型SSD或者YOLO来完成。
1.4. 读取数据
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
# 数据标注处理
def parse_json(d):
arr = np.array([d['top'], d['height'], d['left'], d['width'], d['label']])
arr = arr.astype(int)
return arr
#训练集标签载入
train_json = json.load(open('../input/train.json'))
img = cv2.imread('../input/train/000000.png')
arr = parse_json(train_json['000000.png'])
#图片分割
for idx in range(arr.shape[1]):
plt.subplot(1, arr.shape[1]+1, idx+2)
plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx],arr[2, idx]:arr[2, idx]+arr[3, idx]])
plt.title(arr[4, idx])
plt.xticks([]); plt.yticks([])
1.5. 评测指标
以标签整体识别准确率为评价指标。
一张图片的识别结果与其标签完全相同即为正确,其中任何一个字符不同都视为是错误。
最终正确率具体计算公式如下: