1. The torch.utils.data.DataLoader class:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
- Purpose: the core class for loading data; it returns an iterable over the data.
- torch.utils.data.DataLoader is a key data-reading interface in PyTorch, defined in the dataloader.py script; virtually any model trained with PyTorch uses it.
- Its main job is to take the output of a custom dataset interface (or of one of PyTorch's built-in dataset interfaces) and package it into Tensors of size batch_size; the batch then only needs to be wrapped in a Variable to serve as model input. The interface thus bridges the dataset and the model, which makes it important.
- Parameters:
* dataset (Dataset): the dataset from which to load the data.
* batch_size (int, optional): how many samples to load per batch (default: 1).
* shuffle (bool, optional): set to ``True`` to reshuffle the data at every epoch (default: False).
* sampler (Sampler, optional): defines the strategy for drawing samples from the dataset, returning one sample at a time.
* batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
* num_workers (int, optional): how many subprocesses to use for data loading. 0 means the data is loaded in the main process (default: 0).
* collate_fn (callable, optional): merges a list of samples to form a mini-batch. # "callable" means any callable object
* pin_memory (bool, optional): if ``True``, the data loader copies tensors into CUDA pinned memory before returning them.
* drop_last (bool, optional): set to ``True`` to drop the last incomplete batch when the dataset size is not divisible by the batch size (default: False).
* timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative (default: 0).
* worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: None).
# num_workers=0 means data loading happens in the main process; any value greater than 0 loads data through multiple worker processes, which can speed up data loading.
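As a quick sketch of how these parameters interact (using a toy TensorDataset; the dataset and the sizes here are invented for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset (invented for illustration): 10 samples, 3 features each.
data = torch.randn(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(data, labels)

# num_workers=0 loads in the main process; drop_last=False keeps the
# final, smaller batch (10 samples = 4 + 4 + 2).
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0,
                    drop_last=False)

batch_sizes = [batch_data.shape[0] for batch_data, batch_labels in loader]
print(batch_sizes)  # [4, 4, 2]
```

With drop_last=True the same loader would yield only the two full batches.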
When self.num_workers equals 0, multi-process data reading is not used. First, indices = next(self.sample_iter) fetches a list named indices whose length equals the batch size; each value in this list is the index of one sample in the batch, and every next call reads another batch_size-long list of indices. Then self.collate_fn packs the batch_size tuples (each of length 2: the first element is the data, a Tensor; the second is the label, an int) into one list of length 2 whose elements are both Tensors — a FloatTensor built from the batch_size data items and a LongTensor built from the batch_size labels. In short, self.collate_fn merges batch_size scattered Tensors into a single Tensor.
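This collate behavior can be observed directly with the default collate function (a minimal sketch; the sample values are invented):

```python
import torch
from torch.utils.data.dataloader import default_collate

# Three (data, label) tuples, as a Dataset would yield them one at a time
# (values invented for illustration).
samples = [(torch.tensor([1.0, 2.0]), 0),
           (torch.tensor([3.0, 4.0]), 1),
           (torch.tensor([5.0, 6.0]), 2)]

batch = default_collate(samples)
# The scattered tensors are stacked into one FloatTensor of shape (3, 2),
# and the int labels into one LongTensor of shape (3,).
print(batch[0].shape, batch[0].dtype)  # torch.Size([3, 2]) torch.float32
print(batch[1])                        # tensor([0, 1, 2])
```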
The self.num_workers-related statements initialize either the multi-process or the single-process case; if multi-process reading is not configured, these initializations are unnecessary. Single-process reading is described later.
A simple queue object is created with the multiprocessing.SimpleQueue() class. The multiprocessing.Process class constructs processes; it is invoked once per configured worker, and the resulting processes are assigned to self.workers. The following for loop then starts each process in self.workers by calling its start method.
If multi-process reading is configured, data is read through queues; otherwise it is read in the ordinary single-process way.
2. DataLoader class source code:
First look at the important inputs to __init__, i.e. the parameters, which were explained above.
In __init__, the RandomSampler class performs random sampling without repetition, which is exactly what implements shuffle.
The BatchSampler class then wraps batch_size draws from a RandomSampler object into one unit, achieving random selection of an entire batch. Both sampler classes are defined in the sampler.py script: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py. All of the above happens at initialization time. When the code reaches the point of pulling data from an object created by the torch.utils.data.DataLoader class, for example:
train_data = torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    ...
this calls the DataLoader class's __iter__ method, which is just one line: return DataLoaderIter(self), whose input is exactly the DataLoader object's attributes. So calling __iter__ brings in another class, DataLoaderIter, which is introduced next.
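The two sampler classes can also be exercised on their own (a small sketch using a plain range as a stand-in dataset):

```python
from torch.utils.data.sampler import (RandomSampler, SequentialSampler,
                                      BatchSampler)

# BatchSampler groups indices from an underlying sampler into lists of
# length batch_size; the last list is shorter when drop_last=False.
batches = list(BatchSampler(SequentialSampler(range(10)), batch_size=3,
                            drop_last=False))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# With RandomSampler the indices form a random permutation, so each epoch
# yields shuffled, non-repeating batches; drop_last=True discards the
# incomplete final batch.
shuffled = list(BatchSampler(RandomSampler(range(10)), batch_size=3,
                             drop_last=True))
print(len(shuffled))  # 3 batches of 3 random indices
```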
class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy
            tensors into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete
            batch, if the dataset size is not divisible by the batch size. If
            ``False`` and the size of dataset is not divisible by the batch
            size, then the last batch will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for
            collecting a batch from workers. Should always be non-negative.
            (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called
            on each worker subprocess with the worker id as input, after
            seeding and before data loading. (default: None)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
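Consistent with __len__ delegating to the batch sampler, the loader's length is the number of batches, not the number of samples (a sketch with an invented 10-sample dataset):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(10, 2), torch.zeros(10))

# len(loader) == len(batch_sampler): ceil(10 / 4) = 3 when the incomplete
# batch is kept, floor(10 / 4) = 2 with drop_last=True.
print(len(DataLoader(dataset, batch_size=4)))                  # 3
print(len(DataLoader(dataset, batch_size=4, drop_last=True)))  # 2
```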
Original blog post: https://blog.csdn.net/u014380165/article/details/79058479