街景字符編碼識別cv入門賽-02數據讀取與數據擴增

街景字符編碼識別cv入門賽-02數據讀取與數據擴增

1、數據讀取

數據集包括訓練集(30000張,像素值不等),驗證集(10000張)和測試集(40000張)
用pytorch 寫一個datasets的類

from torch.utils.data.datasets import Dataset
class SVHNDataset(Dataset):
 	# __init__ 包括圖片的路徑、標籤、圖片的變換器 
	def __init__(self, img_path, img_label, transform=None):
		self.img_path = img_path
		self.img_label = img_label
		if transform is not None:
			self.transform = transform
		else:
			self.transform = None
	# index爲圖片的索引 __getitem__返回圖片和標籤(Tensor格式)
	def __getitem__(self, index):
		img = cv2.imread(img_path[index])   # opencv 讀取的img爲BGR格式
		img = cv2.cvtCOLOR(img, cv2.COLOR_BGR2RGB) # 轉換爲RGB格式
		# 這裏爲albumentation 包的讀取形式
		if transform is not None:
			augmented = self.transform(image=img)
			img = augmented["image"]
		
		lbl = np.array(self.img_label[index], dtype=np.int)
		lbl = list[lbl] + [10]*(5-len(lbl))

	return img, torch.from_numpy(np.array(lbl))     # torch.from_numpy:將np.array 轉爲 torch中的tensor類型
	
	def __len__(self):
		return(len(self.img_path))
		

2、 數據擴增

數據擴增用的爲albumentation 庫 ,這個庫裏面的數據擴增方法十分豐富,而且處理速度較快,可以無縫嵌入到pytroch中,非常的方便,具體的用法可以看github上的示例代碼

from albumentation.pytorch import ToTensor
from albumentations import (Compose, Resize, RandomCrop, HueSaturationValue, Rotate, Normalize, Cutout, GaussNoise, ElasticTransform)

transform_train = Compose([
	ToTensor(),       # img/255 歸一化操作
	Normalize(
	        mean=[0.485, 0.456, 0.406],
	        std=[0.229, 0.224, 0.225],      # img中的channel=channel-mean/std  
	    ),
	Resize(64,128),
	# 各種數據擴增操作
	ElasticTransform(),
    Cutout(),		# 加入方塊的馬斯克
    GaussNoise(),	# 高斯噪聲
    Resize(64, 128),  # resize
#   RandomCrop(60, 120),	# 隨機裁剪
    HueSaturationValue(0.3, 0.3, 0.2),   # 飽和度
    Rotate(10),  # 旋轉10°
])

transform_val = Compose([
	ToTensor(),
	Normalize(
	        mean=[0.485, 0.456, 0.406],
	        std=[0.229, 0.224, 0.225],      # img中的channel=channel-mean/std  
	# 驗證集中不進行數據擴增
])

數據擴增後的圖片效果

# 圖片擴增前
import matplotlib.pyplot as plt
image_path = train_path[0]
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.show()

# 擴增後
from albumentations import (Cutout, HorizontalFlip, GaussNoise)
def augment_and_show(aug, image):
    image = aug(image=image)['image']
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    
aug = GaussNoise(p=1)
augment_and_show(aug, image)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章