pytorch學習筆記(4):dataset dataloder 一定堅持學完啊!!

dataset與dataloder

1.dataset:

torch.utils.data.Dataset()
Dataset抽象類,所有自定義的dataset需要繼承它,
getitem:接受一個索引,返回一個樣本

class Dataset(object):
    def __init__(self, ):
    def __len__(self):
    def __getitem__(self, ):

例如:
class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        :param data_dir: str, 數據集所在路徑
        :param transform: torch.transform,數據預處理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存儲所有圖片路徑和標籤,在DataLoader中通過index讀取樣本
        self.transform = transform

    def __getitem__(self, index): #根據index索引返回img,label
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在這裏做transform,轉爲tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

2.dataloder

torch.utils.data.DataLoder(
dataset, dataset類,決定數據從哪讀取以及如何讀取
batch_size=1, 批大小
shuffle=False, 每個epoch是否亂序
sampler,
batch_sampler,
num_workers=0,是否多進程讀取數據
collate_fn,
pin_memory
drop_last=False,當樣本數不能被batchsize整除時,是否捨棄最後一批數據
timeout,
worker_init_fn
multiprocessing_context
)
所有訓練樣本都已經輸入到模型中,稱爲一個epoch
一批樣本輸入到模型中,稱之爲一個iteration
批大小,決定一個epoch有多少個iteration

如:樣本總數:80,batchsize:8
1 epoch=10 iteration

這兩個我自己看的似懂非懂的樣子,以後再補充

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章