1 Dataset-庖丁解牛之pytorch

1 數據庫基類

用來實現數據的大小和索引。
pytorch的Dataset類是一個抽象類,只先實現了三個魔法方法

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

如描述所說,這是一個抽象類,其他數據庫類應該是它的子類,所有子類應該重載如下兩個函數

* __len__函數,用來提供數據庫的大小
* __getitem__函數,支持一個整形索引,重來獲取單個數據,範圍是__len__定義的,範圍是[0, len(self)]

2 數據庫的合併

其中Dataset.add函數返回一個ConcatDataset類,這個類實現了數據庫的合併,針對從基類DataSet派生類,ConcatDataset實現了不同源的數據庫整合,數據存儲在鏈表datasets中,通過累計長度,可以查詢不同的datasets,這個類的詳細描述如下:

class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

注意的是給定索引的時候,需要先判定是哪個數據集,然後判定數據集的索引,getitem函數使用了bitsect.bitsec_right查找數據庫索引,然後計算該數據庫的內部索引。

3 子數據庫Subset

ConcatDataset將不同數據集組成鏈表,在這個大數據集的基礎上,通過索引可以建立一個虛擬數據集,實現不同數據集的一個子集,如果通過隨機函數實現索引,可以混合所有數據集,Subset數據集的源碼如下:

import torch
from torch.utils.data import Dataset, ConcatDataset, Subset, random_split


class MyDataset(Dataset):
    def __init__(self, t=0, name="myDataset"):
        super(MyDataset, self).__init__()
        self.nums = []
        if t == 0:
            self.nums = [torch.randn(1).item() for _ in range(100)]
        elif t == 1:
            self.nums = list(range(230))
        elif t == 2:
            self.nums = torch.linspace(-1, 1, 250).data.numpy()
        self.name = name
        self.t = t

    def __getitem__(self, i):
        return self.nums[i]

    def __len__(self):
        return len(self.nums)


if __name__ == "__main__":
    ds0 = MyDataset(0, "type_0")
    ds1 = MyDataset(1, "type_1")
    ds2 = MyDataset(2, "type_2")
    ds = ds0 + ds1
    ds = ds + ds2
    print(ds.datasets[0].datasets[0].name,ds.datasets[0].datasets[1].name,ds.datasets[1].name)
    print(len(ds))
    dss = random_split(ds, [310, 270]) # 第二個參數是長度,累積和是數據集長度

此處要注意的是 ds0和ds1首先進行合併,形成一個ConcatDataset,然後和ds2合併,再形成一個ConcatDataset,因此ds的datasets長度爲2,第一個數據是ConcatDataset,第二個數據是MyDataset(2, "type_2")

4 Tensor向量化數據庫

內存數據需要轉爲Tensor才能使用,pytorch提供了TensorDataset類可以直接對Tensor數據進行數據庫封裝

class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

最後介紹一個對數據集進行子集切分的函數

def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

數據集源碼解讀完畢了,雖然這是一個基類,但是提供了一個可迭代的思想,類似於道教的一分爲二,二生四,......,提供了數據索引,合併,tensor,子集的等基本功能。

torchvision.dataset可以使用的數據集

LSUN,  大規模場景理解
LSUNClass
ImageFolder, 圖片目錄的數據集
DatasetFolder 文件目錄的數據集
CocoCaptions,  微軟 MS COCO 相關的 Image Captioning 
CocoDetection MS COCO數據集目標檢測
CIFAR10,  該數據集共有60000張彩色圖像分類數據集
CIFAR100 數據集包含100小類,每小類包含600個圖像,其中有500個訓練圖像和100個測試圖像。100類被分組爲20個大類。每個圖像帶有1個小類的“fine”標籤和1個大類“coarse”標籤。
STL10 *   10個類:飛機,鳥,汽車,貓,鹿,狗,馬,猴子,船,卡車。*   圖像爲96x96像素,顏色。*   500個訓練圖像(10個預定義的摺疊),每個類800個測試圖像。
MNIST, MNIST數據集是一個手寫體數據集
EMNIST, 擴展手寫體數據集
FashionMNIST FashionMNIST 是一個替代 MNIST 手寫數字集[1] 的圖像數據集。 它是由 Zalando(一家德國的時尚科技公司)旗下的研究部門提供。其涵蓋了來自 10 種類別的共 7 萬個不同商品的正面圖片。
SVHN
PhotoTour
FakeData
SEMEION 圖像處理_Semeion Handwritten Digit Data Set(Semeion手寫體數字數據集)
Omniglot Omniglot是一個在線的語言文字百科,其內涵蓋了已知的全部書寫系統
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章