街景字符編碼識別cv入門賽-02數據讀取與數據擴增
1、數據讀取
數據集包括訓練集(30000張,像素值不等),驗證集(10000張)和測試集(40000張)
用 PyTorch 寫一個繼承自 Dataset 的數據集類
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class SVHNDataset(Dataset):
    """Street-view character dataset yielding ``(image, label)`` pairs.

    Each sample is read from disk with OpenCV, converted from BGR to RGB,
    optionally passed through an albumentations-style transform, and paired
    with a fixed-length label tensor.

    Args:
        img_path: Sequence of image file paths.
        img_label: Sequence of per-image digit sequences, aligned with
            ``img_path``.
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a mapping containing an ``"image"`` entry.
    """

    # Every label is padded (or truncated) to this many entries so that
    # samples can be stacked into a batch.
    MAX_DIGITS = 5
    # Class index used to fill the unused label positions.
    PAD_CLASS = 10

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        path = self.img_path[index]
        img = cv2.imread(path)  # OpenCV decodes into BGR channel order
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # albumentations calling convention: keyword in, dict out.
        if self.transform is not None:
            img = self.transform(image=img)["image"]

        return img, self._encode_label(self.img_label[index])

    @staticmethod
    def _encode_label(label):
        """Pad or truncate ``label`` to ``MAX_DIGITS`` and return an int64 tensor."""
        digits = np.asarray(label, dtype=np.int64).ravel()[:SVHNDataset.MAX_DIGITS]
        padded = np.full(SVHNDataset.MAX_DIGITS, SVHNDataset.PAD_CLASS, dtype=np.int64)
        padded[:digits.size] = digits
        # torch.from_numpy wraps the NumPy array as a torch tensor.
        return torch.from_numpy(padded)

    def __len__(self):
        return len(self.img_path)
2、 數據擴增
數據擴增使用的是 albumentations 庫。這個庫提供的數據擴增方法十分豐富,處理速度也較快,可以無縫嵌入到 PyTorch 中,非常方便;具體用法可以參考其 GitHub 倉庫中的示例代碼。
from albumentations.pytorch import ToTensor
from albumentations import (Compose, Resize, RandomCrop, HueSaturationValue, Rotate, Normalize, Cutout, GaussNoise, ElasticTransform)

# Training pipeline. Order matters: geometric/photometric augmentations run
# on the raw RGB array first, Normalize comes next, and the tensor conversion
# is the very last step because every preceding transform operates on NumPy
# HWC arrays.
# NOTE(review): ToTensor is deprecated upstream in favour of ToTensorV2 —
# switch once the installed albumentations version is confirmed.
transform_train = Compose([
    Resize(64, 128),  # bring every image to a common 64x128 size
    # Augmentations
    ElasticTransform(),
    Cutout(),  # mask out random square patches
    GaussNoise(),  # additive Gaussian noise
    # RandomCrop(60, 120),  # optional random crop, disabled
    # NOTE(review): the positional arguments are the hue/saturation/value
    # shift limits; 0.3/0.3/0.2 may be much weaker than intended — confirm.
    HueSaturationValue(0.3, 0.3, 0.2),
    Rotate(10),  # rotate by up to 10 degrees
    # Normalize accounts for the 0-255 input range through its
    # max_pixel_value default, so no separate /255 step is needed.
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensor(),
])
# Validation pipeline: deterministic preprocessing only, no random augmentation.
transform_val = Compose([
    Resize(64, 128),  # same input size the training pipeline resizes to
    # Normalize accounts for the 0-255 input range through its
    # max_pixel_value default; it must run before the tensor conversion.
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensor(),
])
數據擴增後的圖片效果
# 圖片擴增前
import matplotlib.pyplot as plt
# Display the first training image as-is, before any augmentation.
image_path = train_path[0]
image = cv2.imread(image_path)  # OpenCV decodes into BGR channel order
if image is None:
    # cv2.imread reports failure by returning None instead of raising.
    raise FileNotFoundError(f"cannot read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.show()
# 擴增後
from albumentations import (Cutout, HorizontalFlip, GaussNoise)
def augment_and_show(aug, image):
    """Apply ``aug`` to ``image`` and display the result in a new figure.

    Args:
        aug: Callable invoked as ``aug(image=image)`` that returns a mapping
            with an ``"image"`` entry (the albumentations convention).
        image: Image array to augment; it is passed to ``aug`` unchanged.
    """
    augmented = aug(image=image)["image"]
    plt.figure(figsize=(10, 10))
    plt.imshow(augmented)
    # Render explicitly, matching the plain display snippet elsewhere in
    # this file, so the figure also appears outside inline notebook mode.
    plt.show()
# Demo: build a Gaussian-noise transform with p=1 and preview it on the
# image loaded earlier.
aug = GaussNoise(p=1)
augment_and_show(aug=aug, image=image)