PyTorch Source Code Walkthrough: torch.utils.data.DataLoader

1. The torch.utils.data.DataLoader class:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
  • Purpose: the core class for loading data; it returns an iterable over the dataset.
  • torch.utils.data.DataLoader is a key data-loading interface in PyTorch, defined in dataloader.py; almost any model trained with PyTorch uses it.
  • Its main job is to take the output of a custom data-reading interface, or of one of PyTorch's built-in ones, and pack it into Tensors of batch_size samples; these only need to be wrapped in a Variable before being fed to the model. The interface thus bridges data reading and the model, which makes it quite important.

  • Parameters:
* dataset (Dataset): the dataset from which to load the data.
* batch_size (int, optional): how many samples to load per batch (default: 1).
* shuffle (bool, optional): set to True to have the data reshuffled at every epoch (default: False).
* sampler (Sampler, optional): defines the strategy for drawing samples from the dataset, returning one sample at a time.
* batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
* num_workers (int, optional): how many subprocesses to use for data loading; 0 means the data will be loaded in the main process (default: 0).
* collate_fn (callable, optional): merges a list of samples to form a mini-batch.  # a callable is any object that can be invoked like a function
* pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned memory before returning them.
* drop_last (bool, optional): set to True to drop the last incomplete batch when the dataset size is not divisible by the batch size (default: False).
* timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers; should always be non-negative (default: 0).
* worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: None).

# num_workers = 0 means data loading happens in the main process; any value greater than 0 loads data through multiple worker processes, which can speed up data loading.

When self.num_workers equals 0, multiprocessing is not used for data reading. First, indices = next(self.sample_iter) fetches a list named indices of length batch_size; each value in this list is the index of one sample in the batch, and every next call reads another batch_size-long list of indices. Then self.collate_fn packs the batch_size tuples (each of length 2, where the first element is the data, a Tensor, and the second is the label, an int) into a single list of length 2 whose two elements are both Tensors: a FloatTensor made of the batch_size data samples and a LongTensor made of the batch_size labels. Simply put, self.collate_fn merges batch_size scattered Tensors into one Tensor.
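The single-process step above can be sketched in plain Python (no torch here, so lists stand in for Tensors; sample_iter and collate are illustrative stand-ins, not the real DataLoader internals):

```python
# A toy dataset of (data, label) tuples; in PyTorch the data would be a Tensor.
dataset = [([float(i)], i % 2) for i in range(10)]

batch_size = 4
sample_iter = iter([[0, 1, 2, 3], [4, 5, 6, 7]])  # what a batch sampler would yield

def collate(samples):
    # Mimics default_collate: batch_size (data, label) tuples
    # become [stacked data, stacked labels].
    data = [d for d, _ in samples]    # would become one FloatTensor
    labels = [l for _, l in samples]  # would become one LongTensor
    return [data, labels]

indices = next(sample_iter)                     # indices of one batch
batch = collate([dataset[i] for i in indices])  # fetch the samples, then merge them
```

Each call to next(sample_iter) advances to the next batch of indices, which is exactly what happens once per iteration of the training loop.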

The self.num_workers-related statements perform initialization for the multiprocess case; if multiprocess reading is not enabled, none of this initialization is needed. Single-process reading is covered later.

multiprocessing.SimpleQueue() creates a simple queue object. multiprocessing.Process is the class that constructs a process; one process is launched per configured worker, and the result is assigned to self.workers. The following for loop then starts each process in self.workers by calling its start method.

With multiprocess reading enabled, data is read through these queues; without it, data is read in the ordinary, direct way.
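The queue-based pattern can be sketched as follows (threads stand in for multiprocessing.Process here so the example stays self-contained and deterministic to run; the real DataLoader uses multiprocessing.SimpleQueue and multiprocessing.Process):

```python
import queue
import threading

dataset = [(i, i * 10) for i in range(8)]  # stand-in for a Dataset

index_queue = queue.SimpleQueue()  # main process -> workers: batches of indices
data_queue = queue.SimpleQueue()   # workers -> main process: fetched batches

def worker_loop():
    # Each worker repeatedly pulls a batch of indices, fetches the
    # corresponding samples, and pushes the batch back to the main process.
    while True:
        indices = index_queue.get()
        if indices is None:  # sentinel: shut this worker down
            break
        data_queue.put([dataset[i] for i in indices])

workers = [threading.Thread(target=worker_loop) for _ in range(2)]
for w in workers:
    w.start()

for batch_indices in [[0, 1], [2, 3], [4, 5], [6, 7]]:
    index_queue.put(batch_indices)
for _ in workers:
    index_queue.put(None)

batches = [data_queue.get() for _ in range(4)]
for w in workers:
    w.join()
```

Note that with multiple workers the batches may come back out of order, which is why the real DataLoaderIter tracks batch indices to reorder results.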

 

2. DataLoader class source code:

First look at the important inputs to __init__, i.e. the parameters, which were explained above.

In __init__, the RandomSampler class samples randomly and without replacement, so it is what implements the shuffle behaviour.

The BatchSampler class then wraps a RandomSampler so that batch_size indices are returned at a time, which is how a random batch is selected. Both sampler classes are defined in sampler.py: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py. All of the above happens at initialization time. When the code reaches the point of pulling data from the object created by torch.utils.data.DataLoader, for example:
train_data=torch.utils.data.DataLoader(...) 
for i, (input, target) in enumerate(train_data): 
... 
the __iter__ method of the DataLoader class is invoked. __iter__ is a single line, return DataLoaderIter(self), whose input is precisely the DataLoader object's attributes. Calling __iter__ therefore brings in another class, DataLoaderIter, introduced next.
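The RandomSampler + BatchSampler combination described above can be mimicked in plain Python (shuffling and chunking only; random_sampler and batch_sampler are illustrative stand-ins, not the real sampler.py implementations):

```python
import random

def random_sampler(n, seed=0):
    # Like RandomSampler: yields every index exactly once, in random order.
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    yield from indices

def batch_sampler(sampler, batch_size, drop_last):
    # Like BatchSampler: groups indices from the wrapped sampler into batches.
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:  # emit the final, smaller batch
        yield batch

batches = list(batch_sampler(random_sampler(10), batch_size=4, drop_last=False))
```

With drop_last=True the final length-2 batch would be discarded instead, which is exactly the drop_last semantics documented for DataLoader.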

class DataLoader(object):
"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id as input, after seeding and before data
            loading. (default: None)
"""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

 

Original blog post: https://blog.csdn.net/u014380165/article/details/79058479
