文章目錄
0. 前言
1. Dataset 相關源碼
- 源碼位於
torch/utils/data/dataset.py
Dataset
:定義了數據集的基本形式(通過下標獲取元素)。IterableDataset
:定義了iterable數據集的基本形式(通過迭代器獲取元素)。TensorDataset
:輸入若干個tensor,將每個tensor中對應元素組成爲元組,作爲數據集元素。ConcatDataset
:合併基本形式的數據集(通過下標獲取元素)。ChainDataset
:合併ierable數據集。Subset
:定義基本形式數據集(通過下標獲取元素)的子集。random_split
:將基本形式數據集(通過下標獲取元素)分爲若干子集。
1.1. Dataset
- 定義了PyTorch中數據集基本形式,key-value形式,即通過key獲取對應的樣本數據。key可以是數字,也可以是字符串。
- 定義了兩個基本函數,
__getitem__
實現key-value結構,__add__
定義兩個數據集疊加的操作。
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
1.2. IterableDataset
- Iterable的數據集,其實就是增加了一個
__iter__
函數 - 注意,數據集合並後方法
__add__
的實現有所變化。 - 註釋中給出了分佈式訓練時的樣例,
class IterableDataset(Dataset):
def __iter__(self):
raise NotImplementedError
def __add__(self, other):
return ChainDataset([self, other])
1.3. TensorDataset
- 應用場景:有若干個tensor,每個樣本是從每個tensor中獲取一個元素構成的。
- 例如,有一個image name list和一個label list,那每個樣本就是image list中的一個元素和label list中的一個元素組成。
- 實現的功能是:
- 輸入一組tensor,要求每個tensor的第一維shape的數值是一樣的。
- 數據集大小就時tensor的第一維shape數值。
- 每個樣本就是tensor列表中分別獲取一個元素,由這些元素組成的元組。
class TensorDataset(Dataset):
def __init__(self, *tensors):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
1.4. ConcatDataset
- 作用:多個非
IterableDataset
數據集的合併。 - 實現原理
- 底層各個數據庫保存在一個list中。
- 記錄一個cumulative_sizes列表,用於保存每個dataset有多少元素。
- 在通過idx獲取元素的時候,通過cumulative_sizes判斷是第幾個dataset的第幾個元素。
class ConcatDataset(Dataset):
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
for d in self.datasets:
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
# 通過數據集id和數據集中元素id來獲取
return self.datasets[dataset_idx][sample_idx]
@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
1.5. ChainDataset
- 作用:多個
IterableDataset
的合併。 - 實現原理
- 在定義對象的時候,其實就是保存了一下輸入的datasets,其他啥都沒做。所以文檔中說,定義該對象是on-the-fly,非常高效。
- 實現過程也非常容易,其實就是通過迭代器、返回迭代器(即每個
IterableDataset
對象都是一個迭代器)。
class ChainDataset(IterableDataset):
def __init__(self, datasets):
super(ChainDataset, self).__init__()
self.datasets = datasets
def __iter__(self):
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
for x in d:
yield x
def __len__(self):
total = 0
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
total += len(d)
return total
1.6. Subset
- 作用:構建數據集的子集。
- 實現原理
- 輸入一個數據集(成爲raw dataset)和一個下標集合(稱爲raw index),選擇數據集中這些指定下標的元素,將這些結果作爲一個子集(稱爲sub dataset)。
- 在代碼實現中,其實就是保存了下標集合(raw index)和原始數據集對象(raw dataset),獲取子集對象的時候就是通過子集下標(subset index)獲取raw index中對應位置的id,然後通過獲取的raw index獲取raw dataset中的元素。
class Subset(Dataset):
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
1.7. 方法 random_split
- 作用:將輸入的數據集,分爲指定長度的若干個子集。
def random_split(dataset, lengths):
if sum(lengths) != len(dataset):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
indices = randperm(sum(lengths)).tolist()
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
2. Sampler 源碼
- 源碼位於
torch/utils/data/sampler.py
Sampler
:所有Sampler的父類。SequentialSampler
:順序依次獲取下標。RandomSampler
:亂序獲取下標。SubsetRandomSampler
:某個子集內亂序獲取下標。WeightedRandomSampler
:爲每個樣本設置權重,權重大表示獲取概率高。BatchSampler
:即將若干個樣本形成一個batch。
2.1. Sampler
- 作用:定義了
Sampler
的基本形式,作用是定義從數據集中獲取元素的方法。 - 實現原理
- 本質就是定義了構造器和集成了魔法方法
__iter__
。 - 構造器中包含一個數據來源。
- 集成魔法方法
__iter__
是爲了使得Sampler
成爲一個迭代器。 - 註釋很清楚的寫明瞭
Sampler
的作用:providing a way to iterate over indices of dataset elements。
- 本質就是定義了構造器和集成了魔法方法
- 註解中說明了,
Sampler
其實是需要一個__len__
方法,但如果要定義一個方法會存在一些BUG,最好的方法就是不定義。- 其實我沒看懂。
class Sampler(object):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
#
# Many times we have an abstract class representing a collection/iterable of
# data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
# implementing a `__len__` method. In such cases, we must make sure to not
# provide a default implementation, because both straightforward default
# implementations have their issues:
#
# + `return NotImplemented`:
# Calling `len(subclass_instance)` raises:
# TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
# + `raise NotImplementedError()`:
# This prevents triggering some fallback behavior. E.g., the built-in
# `list(X)` tries to call `len(X)` first, and executes a different code
# path if the method is not found or `NotImplemented` is returned, while
# raising an `NotImplementedError` will propagate and and make the call
# fail where it could have use `__iter__` to complete the call.
#
# Thus, the only two sensible things to do are
#
# + **not** provide a default `__len__`.
#
# + raise a `TypeError` instead, which is what Python uses when users call
# a method that is not defined on an object.
# (@ssnl verifies that this works on at least Python 3.7.)
2.2. SequentialSampler
- 作用:順序獲取樣本,總是使用相同的順序。
- 實現思路:沒啥好多說的,就是用了
iter(range(len(data_source)))
作爲迭代器。
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
2.3. RandomSampler
- 作用:隨機獲取樣本,通過設置
replacement
可以控制是否會重複獲取元素。另外,當replacement
爲True時,可通過設置num_samples
來設置獲取樣本的總數。 - 實現思路
- 當
replacement
爲False時,使用torch.randperm(len(dataset))
來獲取index,即每次迭代都獲取所有樣本一次。 - 當
replacement
爲True時,使用torch.randint(high=n, size=(self.num_samples,)
來獲取下標。
- 當
class RandomSampler(Sampler):
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
if self._num_samples is not None and not replacement:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
def __len__(self):
return self.num_samples
2.4. SubsetRandomSampler
- 作用:其實就是打亂給定下標的順序。
class SubsetRandomSampler(Sampler):
def __init__(self, indices):
self.indices = indices
def __iter__(self):
return (self.indices[i] for i in torch.randperm(len(self.indices)))
def __len__(self):
return len(self.indices)
2.5. WeightedRandomSampler
- 作用:給定每個樣本選取的概率,然後根據概率隨機獲取樣本。
- 實現思路
- 通過
weights
設定樣本概率,要求長度就是樣本的數量。 - 通過
num_samples
設置獲取樣本的數量,通過replacement
設置是否可重複獲取樣本。 - 最終實現就是通過
torch.multinomial(self.weights, self.num_samples, self.replacement)
。
- 通過
class WeightedRandomSampler(Sampler):
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
Args:
weights (sequence) : a sequence of weights, not necessary summing up to one
num_samples (int): number of samples to draw
replacement (bool): if ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a
sample index is drawn for a row, it cannot be drawn again for that row.
Example:
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
"""
def __init__(self, weights, num_samples, replacement=True):
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(replacement))
self.weights = torch.as_tensor(weights, dtype=torch.double)
self.num_samples = num_samples
self.replacement = replacement
def __iter__(self):
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
def __len__(self):
return self.num_samples
2.6. BatchSampler
- 作用:用於包裝其他
Sampler
對象,從而獲得一個小批次樣本。 - 實現思路:
- 構造器中輸入一個
Sampler
對象,以及batch size大小。另外,通過drop_last
可以控制是否無視最後一個樣本數小於 batch size 的小批次數據。 - 構造小批次其實就是構建一個長度爲 batch size 的list,每個元素就是輸入
Sampler
獲取的下標。
- 構造器中輸入一個
class BatchSampler(Sampler):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
if not isinstance(sampler, Sampler):
raise ValueError("sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}"
.format(sampler))
if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
3. DataLoader
- DataLoader 的定位:
- 輸入dataset以及sampler對象,獲得一個可迭代(iterable)對象。
- 提供基本功能 batch 與 shuffle。
- 提供更快的獲取數據的功能:如
num_workers
設置多線程讀取數據,pin_memory
鎖頁內存功能(將數據轉換到GPU中會更快一些)。 - 不提供數據增強功能。
- 實現原理:
- 在構造器中根據輸入參數構建
dataset/num_workers/pin_memory/timeout/batch_size/drop_last/sampler/batch_sampler/
等參數。 - 在構建迭代器時,調用
__iter__
方法,若沒有設置num_workers
則使用但進程迭代器_SingleProcessDataLoaderIter
,否則就使用多進程迭代器_MultiProcessingDataLoaderIter
。
- 在構造器中根據輸入參數構建
3.1. 初始化規則
- 能夠輸入的參數包括:
dataset
:torch.utils.data.Dataset
對象,分爲 map-style(即普通Dataset) 和 iterable-style (即IterableDataset
對象)兩類。batch_size=1
shuffle=False
sampler
:torch.utils.data.sampler.Sampler
對象。batch_sampler
:與sampler類似,但返回一系列index。num_workers=0
:設置多線程獲取數據collate_fn
:將一組sample轉換爲Tensor。pin_memory=False
:是否使用所頁內存。drop_last=False
:是否忽略最後一個sample數量不足batch size的batch。timeout=0
:表示從workers中獲取一個batch所用時間的限制worker_init_fn
:Seeding後/data loading前運行的函數,以worker編號作爲輸入。multiprocessing_context
- 基本規則:
num_workers
和timeout
的數值不能是負數。- 如果輸入的dataset是
IterableDataset
對象,則shuffle必須是False,sampler
和batch_sampler
必須是None
。 sampler
和shuffle
不能同時指定:我猜意思是,如果自定義sampler了,那如果需要shuffle也是在sampler中實現了。batch_sampler
不能同時和batch_size/shuffle/sampler/drop_last
中任意一個元素同時指定。- 如果
batch_size is None
,則不能設置shuffle/drop_last
。
- 設置的成員變量
dataset
:輸入變量直接賦值。num_workers
:輸入變量直接賦值。pin_memory
:輸入變量直接賦值。timeout
:輸入變量直接賦值。worker_init_fn
:輸入變量直接賦值。multiprocesssing_context
:輸入變量直接賦值。_dataset_kind
:dataset
的類型,有_DatasetKind.Map
和_DatasetKind.Iterable
兩個類型。batch_size
:在滿足上述基本規則的情況下,如果設置了batch_sampler
,則batch_size
賦值爲None
,其他情況下就是輸入變量直接賦值。drop_last
:輸入變量直接賦值。sampler
:如果輸入變量不爲None
則直接賦值;如果是iterable dataset,就創建_InfiniteConstantSampler
實例(該類就定義在dataloader.py
中);如果設置了shuffle就創建RandomSampler
實例;一般默認就創建SequentialSampler
實例。batch_sampler
:如果輸入變量不爲None
就直接賦值;否則就創建BatchSampler
實例。_auto_collation
:其實就是看batch_sampler
是否設置,設置了就是True,否則就是False。collate_fn
:如果輸入變量不爲None
就直接賦值;其他情況中,如果設置了batch_sampler
就使用torch.utils.data._utils.collate.default_collate
,沒有設置batch_sampler
就使用torch.utils.data._utils.collate.default_convert
__initialized
:設置爲True。
3.2. 構建迭代器
DataLoader
的主要功能之一就是將Dataset
對象構建爲可迭代對象,所以DataLoader
通過__iter__
構建迭代器。- 迭代器分類:
- 單進程迭代器
_SingleProcessDataLoaderIter
- 多進程迭代器
_MultiProcessingDataLoaderIter
- 單進程迭代器
- 上述兩個迭代器有一個功能的父類
_BaseDataLoaderIter
。代碼不復雜,主要就是先了幾個功能- 初始化了一些列成員變量。
- 通過
iter(self._index_sampler)
設置_sampler_iter
變量。 - 設置需要實現的
__next__
方法。 - 通過調用
next(self._sampler_iter)
來獲取下一個下標。 - 通過
len(self._index_sampler)
來獲取數據長度。
class _BaseDataLoaderIter(object):
def __init__(self, loader):
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
def __iter__(self):
return self
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def __next__(self):
raise NotImplementedError
def __len__(self):
return len(self._index_sampler)
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
- 單線程迭代器
_SingleProcessDataLoaderIter
- 主要實現思路構建 dataset fetcher,先獲取 index,在通過 fetcher 與 index 獲取數據,經過pin memory得到最終數據。
- fetcher 的源碼在
torch.utils.data._utils.fecth.py
中,主要作用就是通過dataset
對象獲取若干樣本,然後通過collate_fn
方法轉換爲 tensor。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
next = __next__ # Python 2 compatibility
-
多線程迭代器
_MultiProcessingDataLoaderIter
- 這個類光註釋就有280行……