賽事筆記2

Datawhale 筆記1-數據讀取與擴增

數據讀取

  • Dataset:對數據集的封裝,提供索引方式的對數據樣本進行讀取 (單個樣本)
  • DataLoder:對Dataset進行封裝,提供批量讀取的迭代讀取
  • 使用torchvision的transform方法進行數據增強;torchvision不僅包括數據增強,還包括數據轉換,是否送入tensor。
# # 定義讀取數據集
class SVHNDataset(Dataset): #重載數據子類
    def __init__(self, img_path, img_label, transform=None): #定義初始化方法
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        # 數據索引的邏輯
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

# # 定義讀取數據dataloader
# 假設數據存放在`../input`文件夾下,並進行解壓。

train_path = glob.glob('../data/train/mchar_train/*.png')
train_path.sort()
train_json = json.load(open('../data/train/mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)#將圖片的尺寸轉換爲64×128
                    transforms.RandomCrop((60, 120)),#沿圖片中心裁剪,裁剪的大小是60×120
                    transforms.ColorJitter(0.3, 0.3, 0.2),#對圖像顏色的對比度、飽和度和零度進行變換
                    transforms.RandomRotation(10),#對圖像進行隨機旋轉
                    transforms.ToTensor(),# 將圖像轉換爲取值範圍爲[0,1.0]的torch.FloadTensor
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#對圖片的像素進行正則化
    ])), 
    batch_size=40, #每批樣本個數
    shuffle=True, #是否打亂順序
    num_workers=0,#讀取的線程個數
)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章