Datawhale 筆記1-數據讀取與擴增
數據讀取
- Dataset:對數據集的封裝,提供索引方式的對數據樣本進行讀取 (單個樣本)
- DataLoder:對Dataset進行封裝,提供批量讀取的迭代讀取
- 使用torchvision的transform方法進行數據增強;torchvision不僅包括數據增強,還包括數據轉換,是否送入tensor。
# # 定義讀取數據集
class SVHNDataset(Dataset): #重載數據子類
def __init__(self, img_path, img_label, transform=None): #定義初始化方法
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 數據索引的邏輯
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
# # 定義讀取數據dataloader
# 假設數據存放在`../input`文件夾下,並進行解壓。
train_path = glob.glob('../data/train/mchar_train/*.png')
train_path.sort()
train_json = json.load(open('../data/train/mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)#將圖片的尺寸轉換爲64×128
transforms.RandomCrop((60, 120)),#沿圖片中心裁剪,裁剪的大小是60×120
transforms.ColorJitter(0.3, 0.3, 0.2),#對圖像顏色的對比度、飽和度和零度進行變換
transforms.RandomRotation(10),#對圖像進行隨機旋轉
transforms.ToTensor(),# 將圖像轉換爲取值範圍爲[0,1.0]的torch.FloadTensor
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#對圖片的像素進行正則化
])),
batch_size=40, #每批樣本個數
shuffle=True, #是否打亂順序
num_workers=0,#讀取的線程個數
)