街景字符编码识别cv入门赛-02数据读取与数据扩增
1、数据读取
数据集包括训练集(30000张,图片尺寸各不相同)、验证集(10000张)和测试集(40000张)。
用 PyTorch 写一个继承自 Dataset 的数据集类:
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class SVHNDataset(Dataset):
    """Dataset of character images paired with fixed-length digit labels.

    Every sample is an image read from disk with OpenCV plus its digit
    sequence, padded (or truncated) to ``MAX_DIGITS`` entries so that all
    label tensors share one shape and can be batched.

    Args:
        img_path: Sequence of image file paths.
        img_label: Sequence of digit sequences, aligned with ``img_path``.
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a mapping holding the transformed image under ``"image"``
            (the albumentations calling convention).
    """

    # Number of label slots per image.
    MAX_DIGITS = 5
    # Class index used to fill the unused label slots.
    PAD_LABEL = 10

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        Returns:
            The RGB image (after ``transform`` when one was given) and a 1-D
            ``int64`` tensor of length ``MAX_DIGITS``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        path = self.img_path[index]
        img = cv2.imread(path)  # OpenCV decodes images in BGR channel order
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)["image"]

        return img, self._encode_label(self.img_label[index])

    def __len__(self):
        """Return the number of samples."""
        return len(self.img_path)

    @classmethod
    def _encode_label(cls, digits):
        """Pad or truncate ``digits`` to ``MAX_DIGITS`` and return a tensor."""
        values = [int(d) for d in digits][: cls.MAX_DIGITS]
        values += [cls.PAD_LABEL] * (cls.MAX_DIGITS - len(values))
        # An explicit int64 dtype keeps the labels platform-independent.
        return torch.from_numpy(np.asarray(values, dtype=np.int64))
2、 数据扩增
数据扩增使用的是 albumentations 库。这个库提供的数据扩增方法十分丰富,处理速度也较快,可以无缝嵌入到 PyTorch 中,非常方便;具体用法可以参考其 GitHub 上的示例代码。
from albumentations.pytorch import ToTensor
from albumentations import (Compose, Resize, RandomCrop, HueSaturationValue, Rotate, Normalize, Cutout, GaussNoise, ElasticTransform)

# Training pipeline. Order matters: the augmentations work on the numpy image,
# Normalize then standardises it, and the tensor conversion comes last because
# the transforms before it cannot process a torch tensor.
# NOTE(review): newer albumentations releases provide ToTensorV2 in place of
# ToTensor - confirm against the installed version.
transform_train = Compose([
    Resize(64, 128),  # bring every image to one common size (height, width)
    # Data augmentation
    ElasticTransform(),  # elastic deformation
    Cutout(),  # mask out square patches
    GaussNoise(),  # additive Gaussian noise
    # RandomCrop(60, 120),  # optional random crop
    HueSaturationValue(0.3, 0.3, 0.2),  # hue / saturation / value shift
    Rotate(10),  # rotation limit of 10 degrees
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],  # per channel: (channel - mean) / std
    ),
    ToTensor(),  # numpy image -> torch tensor
])
# Validation pipeline: deterministic preprocessing only, no augmentation.
# The resize matches the training input size so samples can be batched.
transform_val = Compose([
    Resize(64, 128),
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],  # per channel: (channel - mean) / std
    ),
    ToTensor(),  # must come last: Normalize works on the numpy image
])
数据扩增后的图片效果
# Image before augmentation
import matplotlib.pyplot as plt

image_path = train_path[0]
# Read with OpenCV and convert BGR -> RGB in one step so matplotlib
# displays the colours correctly.
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.show()
# 扩增后
from albumentations import (Cutout, HorizontalFlip, GaussNoise)
def augment_and_show(aug, image):
    """Apply an augmentation to an image and display the result.

    Args:
        aug: Callable invoked as ``aug(image=image)`` that returns a mapping
            holding the augmented image under the ``'image'`` key.
        image: Image handed to ``aug``; the augmented result is what gets
            drawn.
    """
    augmented = aug(image=image)['image']
    plt.figure(figsize=(10, 10))
    plt.imshow(augmented)
    # Without an explicit show() nothing is rendered outside a notebook.
    plt.show()
# p=1 is the probability of applying the transform, so the noise is always added.
aug = GaussNoise(p=1)
augment_and_show(aug=aug, image=image)