PyTorch 入門實戰(三)——Dataset和DataLoader

承接上一篇:PyTorch 入門實戰(二)——Variable


目錄

一、概念

二、Dataset的創建和使用

三、DataLoader的創建和使用

*四、將Dataset數據和標籤放在GPU上(會有bug)

五、Dataset和DataLoader總結


一、概念

1.torch.utils.data.dataset這樣的抽象類可以用來創建數據集。學過面向對象的應該清楚,抽象類不能實例化,因此我們需要構造這個抽象類的子類來創建數據集,並且我們還可以定義自己的繼承和重寫方法。

2.這其中最重要的就是__len____getitem__這兩個函數,前者給出數據集的大小,後者是用於查找數據和標籤

3.torch.utils.data.DataLoader是一個迭代器,方便我們去多線程地讀取數據,並且可以實現batch以及shuffle的讀取等。

二、Dataset的創建和使用

1.首先我們需要引入dataset這個抽象類,當然我們還需要引入Numpy:

import torch.utils.data.dataset as Dataset
import numpy as np

2.我們創建Dataset的一個子類:

(1)初始化,定義數據內容和標籤

#初始化,定義數據內容和標籤
def __init__(self, Data, Label):
    self.Data = Data
    self.Label = Label

(2)返回數據集大小

#返回數據集大小
def __len__(self):
    return len(self.Data)

(3)得到數據內容和標籤

#得到數據內容和標籤
def __getitem__(self, index):
    data = torch.Tensor(self.Data[index])
    label = torch.Tensor(self.Label[index])
    return data, label

(4)最終這個子類定義爲:

import torch
import torch.utils.data.dataset as Dataset
import numpy as np
#創建子類
class subDataset(Dataset.Dataset):
    #初始化,定義數據內容和標籤
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回數據集大小
    def __len__(self):
        return len(self.Data)
    #得到數據內容和標籤
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.Tensor(self.Label[index])
        return data, label

值得注意的地方是:

class subDataset(Dataset.Dataset):

如果只寫了Dataset而不是Dataset.Dataset,則會報錯:module.__init__() takes at most 2 arguments (3 given)

                                       

因爲Dataset是module模塊,不是class類,所以需要調用module裏的class才行,因此是Dataset.Dataset!

3.在類外對Data和Label賦值:

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])

4.聲明主函數,主函數創建一個子類的對象,傳入Data和Label參數

if __name__ == '__main__':
    dataset = subDataset(Data, Label)

5.輸出數據集大小和數據:

    print(dataset)
    print('dataset大小爲:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

代碼變爲;

 

import torch
import torch.utils.data.dataset as Dataset
import numpy as np

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#創建子類
class subDataset(Dataset.Dataset):
    #初始化,定義數據內容和標籤
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回數據集大小
    def __len__(self):
        return len(self.Data)
    #得到數據內容和標籤
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label

if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小爲:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

結果爲:

                                                     

三、DataLoader的創建和使用

1.引入DataLoader:

import torch.utils.data.dataloader as DataLoader

2. 創建DataLoader,batch_size設置爲2,shuffle=False不打亂數據順序,num_workers= 4使用4個子進程

    #創建DataLoader迭代器
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)

3.使用enumerate訪問可遍歷的數組對象:

    for i, item in enumerate(dataloader):
        print('i:', i)
        data, label = item
        print('data:', data)
        print('label:', label)

4.最終代碼如下:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#創建子類
class subDataset(Dataset.Dataset):
    #初始化,定義數據內容和標籤
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回數據集大小
    def __len__(self):
        return len(self.Data)
    #得到數據內容和標籤
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label

if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小爲:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

    #創建DataLoader迭代器
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader):
        print('i:', i)
        data, label = item
        print('data:', data)
        print('label:', label)

結果爲:

                                                  

可以看到兩個對象,因爲對象數*batch_size就是數據集的大小__len__

*四、將Dataset數據和標籤放在GPU上(會有bug)

1.改寫__getitem__函數

        if torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()

代碼變爲:

    #得到數據內容和標籤
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        if torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()
        return data, label

2.報錯啦:

 文字描述爲:

THCudaCheck FATIHCudaCheck FAIL file=Lc:\n efwile=-builder_3\win-whce:el\\pnyteorwch-\tborucihl\cdsrec\rge_3n\weirinc\StorageSharing.cpp-w helienl\epy=t2or3ch1\ toercrhr\cosrrc=\g71e ne:r ioc\pSteorartagieSohanr niotng .cspupppo line=231 error=rt7e1d
: operProcess Process-2:
ation not supportedTraceback (most recent call last):

  File "D:\Anaconda3\lib\multiprocessing\process.py", line 258, in _bootstrap
    self.run()
  File "D:\Anaconda3\lib\multiprocessing\process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 110, in _worker_loop
    data_queue.put((idx, samples))
Process Process-1:
  File "D:\Anaconda3\lib\multiprocessing\queues.py", line 341, in put
    obj = _ForkingPickler.dumps(obj)
  File "D:\Anaconda3\lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "D:\Anaconda3\lib\site-packages\torch\multiprocessing\reductions.py", line 109, in reduce_tensor
    (device, handle, storage_size, storage_offset) = storage._share_cuda_()
RuntimeError: cuda runtime error (71) : operation not supported at c:\new-builder_3\win-wheel\pytorch\torch\csrc\generic\StorageSharing.cpp:231
Traceback (most recent call last):
  File "D:\Anaconda3\lib\multiprocessing\process.py", line 258, in _bootstrap
    self.run()
  File "D:\Anaconda3\lib\multiprocessing\process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "D:\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 110, in _worker_loop
    data_queue.put((idx, samples))
  File "D:\Anaconda3\lib\multiprocessing\queues.py", line 341, in put
    obj = _ForkingPickler.dumps(obj)
  File "D:\Anaconda3\lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "D:\Anaconda3\lib\site-packages\torch\multiprocessing\reductions.py", line 109, in reduce_tensor
    (device, handle, storage_size, storage_offset) = storage._share_cuda_()
RuntimeError: cuda runtime error (71) : operation not supported at c:\new-builder_3\win-wheel\pytorch\torch\csrc\generic\StorageSharing.cpp:231

其實,這是因爲我們使用多個子進程的緣故,博主在這篇博客裏就遇到了這樣的問題:

vs2017 ESRGAN(Enhanced SRGAN)的PyTorch實現


這是Dataloadernum_workers問題,這是PyTorch在windows10上的bug,詳情請見:

(1)Pytorch Windows EOFError: Ran out of input when num_workers>0

(2)EOFError: Ran out of input when enumerating the Train Loader

可見Linux系統完成深度學習還是不錯的~


3.那怎麼辦呢?事先聲明博主使用的是torch 0.4.1版本,我相信在今後可以把這樣的bug解決了。那麼現在我們只需要將num_workers改成0即可

dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 0)

代碼變爲:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#創建子類
class subDataset(Dataset.Dataset):
    #初始化,定義數據內容和標籤
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回數據集大小
    def __len__(self):
        return len(self.Data)
    #得到數據內容和標籤
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        if torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()
        return data, label

if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小爲:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0][0])

    #創建DataLoader迭代器
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 0)
    for i, item in enumerate(dataloader):
        print('i:', i)
        data, label = item
        print('data:', data)
        print('label:', label)

結果爲:

                       

可以看到多了一個device='cuda:0'

五、Dataset和DataLoader總結

1.Dataset是一個抽象類,需要派生一個子類構造數據集,需要改寫的方法有__init____getitem__等。

2.DataLoader是一個迭代器,方便我們訪問Dataset裏的對象,值得注意的num_workers的參數設置:windows系統如果放在cpu上跑,可以不管,但是放在GPU上則需要設置爲0Linux沒有這樣的問題

3.數據和標籤是tuple元組的形式,使用Dataloader然後使用enumerate函數訪問它們

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章