PyTorch 源碼(1) Dataset/Sampler/DataLoader


0. 前言


1. Dataset 相關源碼

  • 源碼位於 torch/utils/data/dataset.py
    • Dataset:定義了數據集的基本形式(通過下標獲取元素)。
    • IterableDataset:定義了iterable數據集的基本形式(通過迭代器獲取元素)。
    • TensorDataset:輸入若干個tensor,將每個tensor中對應元素組成爲元組,作爲數據集元素。
    • ConcatDataset:合併基本形式的數據集(通過下標獲取元素)。
    • ChainDataset:合併ierable數據集。
    • Subset:定義基本形式數據集(通過下標獲取元素)的子集。
    • random_split:將基本形式數據集(通過下標獲取元素)分爲若干子集。

1.1. Dataset

  • 定義了PyTorch中數據集基本形式,key-value形式,即通過key獲取對應的樣本數據。key可以是數字,也可以是字符串。
  • 定義了兩個基本函數,__getitem__實現key-value結構,__add__定義兩個數據集疊加的操作。
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

1.2. IterableDataset

  • Iterable的數據集,其實就是增加了一個 __iter__ 函數
  • 注意,數據集合並後方法__add__的實現有所變化。
  • 註釋中給出了分佈式訓練時的樣例,
class IterableDataset(Dataset):
    def __iter__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ChainDataset([self, other])

1.3. TensorDataset

  • 應用場景:有若干個tensor,每個樣本是從每個tensor中獲取一個元素構成的。
    • 例如,有一個image name list和一個label list,那每個樣本就是image list中的一個元素和label list中的一個元素組成。
  • 實現的功能是:
    • 輸入一組tensor,要求每個tensor的第一維shape的數值是一樣的。
    • 數據集大小就時tensor的第一維shape數值。
    • 每個樣本就是tensor列表中分別獲取一個元素,由這些元素組成的元組。
class TensorDataset(Dataset):
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

1.4. ConcatDataset

  • 作用:多個IterableDataset數據集的合併。
  • 實現原理
    • 底層各個數據庫保存在一個list中。
    • 記錄一個cumulative_sizes列表,用於保存每個dataset有多少元素。
    • 在通過idx獲取元素的時候,通過cumulative_sizes判斷是第幾個dataset的第幾個元素。
class ConcatDataset(Dataset):
    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # 通過數據集id和數據集中元素id來獲取
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

1.5. ChainDataset

  • 作用:多個IterableDataset的合併。
  • 實現原理
    • 在定義對象的時候,其實就是保存了一下輸入的datasets,其他啥都沒做。所以文檔中說,定義該對象是on-the-fly,非常高效。
    • 實現過程也非常容易,其實就是通過迭代器、返回迭代器(即每個IterableDataset對象都是一個迭代器)。
class ChainDataset(IterableDataset):
    def __init__(self, datasets):
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)
        return total

1.6. Subset

  • 作用:構建數據集的子集。
  • 實現原理
    • 輸入一個數據集(成爲raw dataset)和一個下標集合(稱爲raw index),選擇數據集中這些指定下標的元素,將這些結果作爲一個子集(稱爲sub dataset)。
    • 在代碼實現中,其實就是保存了下標集合(raw index)和原始數據集對象(raw dataset),獲取子集對象的時候就是通過子集下標(subset index)獲取raw index中對應位置的id,然後通過獲取的raw index獲取raw dataset中的元素。
class Subset(Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

1.7. 方法 random_split

  • 作用:將輸入的數據集,分爲指定長度的若干個子集。
def random_split(dataset, lengths):
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

2. Sampler 源碼

  • 源碼位於 torch/utils/data/sampler.py
    • Sampler:所有Sampler的父類。
    • SequentialSampler:順序依次獲取下標。
    • RandomSampler:亂序獲取下標。
    • SubsetRandomSampler:某個子集內亂序獲取下標。
    • WeightedRandomSampler:爲每個樣本設置權重,權重大表示獲取概率高。
    • BatchSampler:即將若干個樣本形成一個batch。

2.1. Sampler

  • 作用:定義了 Sampler 的基本形式,作用是定義從數據集中獲取元素的方法。
  • 實現原理
    • 本質就是定義了構造器和集成了魔法方法__iter__
    • 構造器中包含一個數據來源。
    • 集成魔法方法 __iter__ 是爲了使得 Sampler 成爲一個迭代器。
    • 註釋很清楚的寫明瞭 Sampler 的作用:providing a way to iterate over indices of dataset elements。
  • 註解中說明了,Sampler 其實是需要一個 __len__ 方法,但如果要定義一個方法會存在一些BUG,最好的方法就是不定義。
    • 其實我沒看懂。
class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError()`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising an `NotImplementedError` will propagate and and make the call
    #     fail where it could have use `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)

2.2. SequentialSampler

  • 作用:順序獲取樣本,總是使用相同的順序。
  • 實現思路:沒啥好多說的,就是用了 iter(range(len(data_source))) 作爲迭代器。
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

2.3. RandomSampler

  • 作用:隨機獲取樣本,通過設置 replacement 可以控制是否會重複獲取元素。另外,當 replacement 爲True時,可通過設置 num_samples 來設置獲取樣本的總數。
  • 實現思路
    • replacement 爲False時,使用 torch.randperm(len(dataset)) 來獲取index,即每次迭代都獲取所有樣本一次。
    • replacement 爲True時,使用 torch.randint(high=n, size=(self.num_samples,) 來獲取下標。
class RandomSampler(Sampler):
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples

2.4. SubsetRandomSampler

  • 作用:其實就是打亂給定下標的順序。
class SubsetRandomSampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)

2.5. WeightedRandomSampler

  • 作用:給定每個樣本選取的概率,然後根據概率隨機獲取樣本。
  • 實現思路
    • 通過 weights 設定樣本概率,要求長度就是樣本的數量。
    • 通過 num_samples 設置獲取樣本的數量,通過 replacement 設置是否可重複獲取樣本。
    • 最終實現就是通過 torch.multinomial(self.weights, self.num_samples, self.replacement)
class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples

2.6. BatchSampler

  • 作用:用於包裝其他 Sampler 對象,從而獲得一個小批次樣本。
  • 實現思路:
    • 構造器中輸入一個 Sampler 對象,以及batch size大小。另外,通過 drop_last 可以控制是否無視最後一個樣本數小於 batch size 的小批次數據。
    • 構造小批次其實就是構建一個長度爲 batch size 的list,每個元素就是輸入 Sampler 獲取的下標。
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

3. DataLoader

  • DataLoader 的定位:
    • 輸入dataset以及sampler對象,獲得一個可迭代(iterable)對象。
    • 提供基本功能 batch 與 shuffle。
    • 提供更快的獲取數據的功能:如 num_workers 設置多線程讀取數據,pin_memory 鎖頁內存功能(將數據轉換到GPU中會更快一些)。
    • 不提供數據增強功能。
  • 實現原理:
    • 在構造器中根據輸入參數構建 dataset/num_workers/pin_memory/timeout/batch_size/drop_last/sampler/batch_sampler/ 等參數。
    • 在構建迭代器時,調用 __iter__ 方法,若沒有設置num_workers則使用但進程迭代器 _SingleProcessDataLoaderIter,否則就使用多進程迭代器 _MultiProcessingDataLoaderIter

3.1. 初始化規則

  • 能夠輸入的參數包括:
    • datasettorch.utils.data.Dataset 對象,分爲 map-style(即普通Dataset) 和 iterable-style (即IterableDataset對象)兩類。
    • batch_size=1
    • shuffle=False
    • samplertorch.utils.data.sampler.Sampler 對象。
    • batch_sampler:與sampler類似,但返回一系列index。
    • num_workers=0:設置多線程獲取數據
    • collate_fn:將一組sample轉換爲Tensor。
    • pin_memory=False:是否使用所頁內存。
    • drop_last=False:是否忽略最後一個sample數量不足batch size的batch。
    • timeout=0:表示從workers中獲取一個batch所用時間的限制
    • worker_init_fn:Seeding後/data loading前運行的函數,以worker編號作爲輸入。
    • multiprocessing_context
  • 基本規則:
    • num_workerstimeout 的數值不能是負數。
    • 如果輸入的dataset是 IterableDataset 對象,則shuffle必須是False,samplerbatch_sampler 必須是 None
    • samplershuffle 不能同時指定:我猜意思是,如果自定義sampler了,那如果需要shuffle也是在sampler中實現了。
    • batch_sampler 不能同時和 batch_size/shuffle/sampler/drop_last 中任意一個元素同時指定。
    • 如果 batch_size is None,則不能設置 shuffle/drop_last
  • 設置的成員變量
    • dataset:輸入變量直接賦值。
    • num_workers:輸入變量直接賦值。
    • pin_memory:輸入變量直接賦值。
    • timeout:輸入變量直接賦值。
    • worker_init_fn:輸入變量直接賦值。
    • multiprocesssing_context:輸入變量直接賦值。
    • _dataset_kinddataset 的類型,有 _DatasetKind.Map_DatasetKind.Iterable 兩個類型。
    • batch_size:在滿足上述基本規則的情況下,如果設置了 batch_sampler,則 batch_size 賦值爲 None,其他情況下就是輸入變量直接賦值。
    • drop_last:輸入變量直接賦值。
    • sampler:如果輸入變量不爲None則直接賦值;如果是iterable dataset,就創建 _InfiniteConstantSampler 實例(該類就定義在dataloader.py中);如果設置了shuffle就創建 RandomSampler 實例;一般默認就創建 SequentialSampler 實例。
    • batch_sampler:如果輸入變量不爲None就直接賦值;否則就創建BatchSampler實例。
    • _auto_collation:其實就是看batch_sampler是否設置,設置了就是True,否則就是False。
    • collate_fn:如果輸入變量不爲None就直接賦值;其他情況中,如果設置了batch_sampler就使用 torch.utils.data._utils.collate.default_collate,沒有設置batch_sampler就使用 torch.utils.data._utils.collate.default_convert
    • __initialized:設置爲True。

3.2. 構建迭代器

  • DataLoader 的主要功能之一就是將 Dataset 對象構建爲可迭代對象,所以DataLoader通過 __iter__ 構建迭代器。
  • 迭代器分類:
    • 單進程迭代器 _SingleProcessDataLoaderIter
    • 多進程迭代器 _MultiProcessingDataLoaderIter
  • 上述兩個迭代器有一個功能的父類 _BaseDataLoaderIter。代碼不復雜,主要就是先了幾個功能
    • 初始化了一些列成員變量。
    • 通過 iter(self._index_sampler) 設置 _sampler_iter 變量。
    • 設置需要實現的 __next__ 方法。
    • 通過調用 next(self._sampler_iter) 來獲取下一個下標。
    • 通過 len(self._index_sampler) 來獲取數據長度。
class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()

    def __iter__(self):
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def __next__(self):
        raise NotImplementedError

    def __len__(self):
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
  • 單線程迭代器 _SingleProcessDataLoaderIter
    • 主要實現思路構建 dataset fetcher,先獲取 index,在通過 fetcher 與 index 獲取數據,經過pin memory得到最終數據。
    • fetcher 的源碼在 torch.utils.data._utils.fecth.py 中,主要作用就是通過 dataset 對象獲取若干樣本,然後通過 collate_fn 方法轉換爲 tensor。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility
  • 多線程迭代器 _MultiProcessingDataLoaderIter

    • 這個類光註釋就有280行……
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章