Pytorch(二):Dataset和Dataloader的理解

目錄

1.可迭代對象,迭代器

2.數據集遍歷的一般化流程

3.Dataset

4.TensorDataset

5.Dataloader


1.可迭代對象,迭代器

首先,我們要明白python中的兩個概念:可迭代對象,迭代器。

  • 可迭代對象:
  1. 實現了__iter__方法,該方法返回一個迭代器對象。
  • 迭代器:

  1. 一個帶狀態的對象,內部持有一個狀態,該狀態用於記錄當前迭代所在的位置,以方便下次迭代的時候獲取正確的元素。

  2. 迭代器含有__iter__和__next__方法。當調用__iter__返回迭代器自身,當調用next()方法的時候,返回容器中的下一個值。

  3. 迭代器就像一個懶加載的工廠,等到有人需要的時候纔給它生成值返回,沒調用的時候就處於休眠狀態等待下一次調用。

  • iter()函數:

1. 用法一:iter(callable, sentinel)

不停的調用callable,直至其的返回值等於sentinel。其中的callable可以是函數,方法或實現了__call__方法的實例。

2. 用法二:iter(collection)

1)iter()直接調用可迭代對象的__iter__(),並把__iter__()的返回結果作爲自己的返回值,故該用法常被稱爲“創建迭代器”。

2)iter函數可以顯示調用,或當執行“for i in obj:”,Python解釋器會在第一次迭代時自動調用iter(obj),之後的迭代會調用迭代器的next方法,for語句會自動處理最後拋出的StopIteration異常。

3)但iter函數獲取不到 __iter__方法時,還會調用 __getitem__方法,參數是從0開始能獲取值就是可迭代的。

2.數據集遍歷的一般化流程

for i, data in enumerate(dataLoader):

enumerate(dataloader )會調用dataloader 的__iter__()方法, 產生了一個DataLoaderIter(迭代器),接着調用DataLoaderIter __next__()方法來得到batch data。 在__next__()方法方法中使用_next_index()方法獲得索引,接着通過dataset_fetcher的fetch()方法根據index調用dataset的__getitem__()方法, 然後用collate_fn來把它們打包成batch。當數據讀完後, __next__()拋出一個StopIteration異常, for循環結束, dataloader 失效.

3.Dataset

torch.utils.data.Dataset是代表這一數據的抽象類(也就是基類)。我們可以通過繼承重寫這個抽象類實現自己的數據類,只需要定義__len____getitem__這個兩個函數

如果在類中定義了__getitem__()方法,那麼實例對象(假設爲P)就可以這樣P[key]取值。當實例對象做P[key]操作時,就會調用類中的__getitem__()方法。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
  • 通過複寫 __getitem__ 方法可以通過索引index來訪問數據,能夠同時返回數據對應的標籤(label),這裏的數據和標籤都爲tensor類型 
  • 通過複寫 __len__ 方法來獲取數據的個數。 

 比如:

class MyDataset(Dataset): 
    """ my dataset."""
    
    # Initialize your data, download, etc.
    def __init__(self):
        # 讀取csv文件中的數據
        xy = np.loadtxt('data-diabetes.csv', delimiter=',', dtype=np.float32) 
        self.len = xy.shape[0]
        # 除去最後一列爲數據位,存在x_data中
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        # 最後一列爲標籤爲,存在y_data中
        self.y_data = torch.from_numpy(xy[:, [-1]])
        
    def __getitem__(self, index):
        # 根據索引返回數據和對應的標籤
        return self.x_data[index], self.y_data[index]
        
    def __len__(self): 
        # 返回文件數據的數目
        return self.len

4.TensorDataset

TensorDataset是Dataset的子類,已經複寫了__len__和__getitem__方法,只需傳入張量即可。

class TensorDataset(Dataset):
    """Dataset wrapping tensors.
    Each sample will be retrieved by indexing tensors along the first dimension.
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
 
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)
 
    def __len__(self):
        return self.tensors[0].size(0)
  • 比如: 

可以看出我們把X和Y通過Data.TensorDataset() 這個函數拼裝成了一個數據集,數據集的類型是【TensorDataset】

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

5.Dataloader

DataLoader是Pytorch中用來處理模型輸入數據的一個工具類。組合了數據集(dataset) + 採樣器(sampler),並在數據集上提供單線程或多線程(num_workers )可迭代對象

  • 基本概念:
  1. epoch:      所有的訓練樣本輸入到模型中稱爲一個epoch;
  2. iteration:  一批樣本輸入到模型中,成爲一個Iteration;
  3. batchszie:批大小,決定一個epoch有多少個Iteration;

          迭代次數(iteration)=樣本總數(epoch)/批尺寸(batchszie)

  • 函數原型:
torch.utils.data.DataLoader(dataset, batch_size=1, 
    shuffle=False, sampler=None, 
    batch_sampler=None, num_workers=0, 
    collate_fn=None, pin_memory=False, 
    drop_last=False, timeout=0, 
    worker_init_fn=None, multiprocessing_context=None)
  • 參數:
  1. dataset (Dataset) – 決定數據從哪讀取或者從何讀取;

  2. batch_size (python:int, optional) – 批尺寸(每次訓練樣本個數,默認爲1)

  3. shuffle (bool, optional) –每一個 epoch是否爲亂序 (default: False).

  4. num_workers (python:int, optional) – 是否多進程讀取數據(默認爲0);

  5. drop_last (bool, optional) – 當樣本數不能被batchsize整除時,最後一批數據是否捨棄(default: False)

  6. pin_memory(bool, optional) - 如果爲True會將數據放置到GPU上去(默認爲false) 

參考:

https://blog.csdn.net/u014380165/article/details/78634829

https://blog.csdn.net/zw__chen/article/details/82806900

https://www.cnblogs.com/yongjieShi/p/10456802.html

https://www.cnblogs.com/ranjiewen/p/10128046.html

Python 子類繼承父類構造函數:https://www.runoob.com/w3cnote/python-extends-init.html

https://www.ziiai.com/blog/259

Python可迭代對象,迭代器,生成器的區別:https://blog.csdn.net/jinixin/article/details/72232604

完全理解Python迭代對象、迭代器、生成器:https://foofish.net/iterators-vs-generators.html

Pytorch中的數據加載藝術:http://studyai.com/article/11efc2bf

PyTorch 數據集(Dataset):https://geek-docs.com/pytorch/pytorch-tutorial/pytorch-dataset.html

https://www.cnblogs.com/marsggbo/p/11308889.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章