街景字符编码识别cv入门赛-02数据读取与数据扩增

街景字符编码识别cv入门赛-02数据读取与数据扩增

1、数据读取

数据集包括训练集(30000张,像素值不等),验证集(10000张)和测试集(40000张)
用pytorch 写一个datasets的类

from torch.utils.data.datasets import Dataset
class SVHNDataset(Dataset):
 	# __init__ 包括图片的路径、标签、图片的变换器 
	def __init__(self, img_path, img_label, transform=None):
		self.img_path = img_path
		self.img_label = img_label
		if transform is not None:
			self.transform = transform
		else:
			self.transform = None
	# index为图片的索引 __getitem__返回图片和标签(Tensor格式)
	def __getitem__(self, index):
		img = cv2.imread(img_path[index])   # opencv 读取的img为BGR格式
		img = cv2.cvtCOLOR(img, cv2.COLOR_BGR2RGB) # 转换为RGB格式
		# 这里为albumentation 包的读取形式
		if transform is not None:
			augmented = self.transform(image=img)
			img = augmented["image"]
		
		lbl = np.array(self.img_label[index], dtype=np.int)
		lbl = list[lbl] + [10]*(5-len(lbl))

	return img, torch.from_numpy(np.array(lbl))     # torch.from_numpy:将np.array 转为 torch中的tensor类型
	
	def __len__(self):
		return(len(self.img_path))
		

2、 数据扩增

数据扩增用的为albumentation 库 ,这个库里面的数据扩增方法十分丰富,而且处理速度较快,可以无缝嵌入到pytroch中,非常的方便,具体的用法可以看github上的示例代码

from albumentation.pytorch import ToTensor
from albumentations import (Compose, Resize, RandomCrop, HueSaturationValue, Rotate, Normalize, Cutout, GaussNoise, ElasticTransform)

transform_train = Compose([
	ToTensor(),       # img/255 归一化操作
	Normalize(
	        mean=[0.485, 0.456, 0.406],
	        std=[0.229, 0.224, 0.225],      # img中的channel=channel-mean/std  
	    ),
	Resize(64,128),
	# 各种数据扩增操作
	ElasticTransform(),
    Cutout(),		# 加入方块的马斯克
    GaussNoise(),	# 高斯噪声
    Resize(64, 128),  # resize
#   RandomCrop(60, 120),	# 随机裁剪
    HueSaturationValue(0.3, 0.3, 0.2),   # 饱和度
    Rotate(10),  # 旋转10°
])

transform_val = Compose([
	ToTensor(),
	Normalize(
	        mean=[0.485, 0.456, 0.406],
	        std=[0.229, 0.224, 0.225],      # img中的channel=channel-mean/std  
	# 验证集中不进行数据扩增
])

数据扩增后的图片效果

# 图片扩增前
import matplotlib.pyplot as plt
image_path = train_path[0]
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.show()

# 扩增后
from albumentations import (Cutout, HorizontalFlip, GaussNoise)
def augment_and_show(aug, image):
    image = aug(image=image)['image']
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    
aug = GaussNoise(p=1)
augment_and_show(aug, image)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章