Pytorch使用DataLoader批量加載數據

在進行模型訓練時,需要把數據按照固定的形式分批次投餵給模型,在PyTorch中通過torch.utils.data庫的DataLoader完成分批次返回數據。

構造DataLoader首先需要一個Dataset數據源,Dataset完成數據的讀取並可以返回單個數據,然後DataLoader在此基礎上完成數據清洗、打亂等操作並按批次返回數據。

Dataset

PyTorch將數據源分爲兩種類型:類似Map型(Map-style datasets)和可迭代型(Iterable-style datasets)。
Map風格的數據源可以通過索引idx對數據進行查找:dataset[idx],它需要繼承Dataset類,並且重寫__getitem__() 方法完成根據索引值獲取數據和__len__() 方法返回數據的總長度。
可迭代型可以迭代獲取其數據,但沒有固定的長度,因此也不能通過下標獲得數據,通常用於無法獲取全部數據或者流式返回的數據。它繼承自IterableDataset類,並且需要實現__iter__()方法來完成對數據集的迭代和返回。

如下所示爲自定義的數據源MySet,它完成數據的讀取,這裏假定爲[1, 9] 9個數據,然後重寫了__getitem__() 和__len__() 方法

from torch.utils.data import Dataset, DataLoader, Sampler

class MySet(Dataset):
	# 讀取數據
    def __init__(self):
        self.data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
	# 根據索引返回數據
    def __getitem__(self, idx):
        return self.data[idx]
	# 返回數據集總長度
    def __len__(self):
        return len(self.data)

DataLoader

其構造函數如下:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

  • dataset:Dataset類型,從其中加載數據 batch_size:int,可選。每個batch加載多少樣本
  • batch_size: 一個批次的數據個數
  • shuffle:bool,可選。爲True時表示每個epoch都對數據進行洗牌
  • sampler:Sampler,可選。獲取下一個數據的方法。
  • batch_sampler :獲取下一批次數據的方法
  • num_workers:int,可選。加載數據時使用多少子進程。默認值爲0,表示在主進程中加載數據。
  • collate_fn:callable,可選,自定義處理數據並返回。
  • pin_memory:bool,可選,True代表將數據Tensor放入CUDA的pin儲存
  • drop_last:bool,可選。True表示如果最後剩下不完全的batch,丟棄。False表示不丟棄。

Sampler索引

既然DataLoader根據索引值從Dataset中獲取數據,那麼如何獲取一個批次數據的索引,索引值應該如何排列才能實現隨機的效果?這就需要Sampler了,它可以對索引進行shuffle操作來打亂順序,並且根據batch size一次返回指定個數的索引序列。在初始化DataLoader時通過sampler屬性指定獲取下一個數據的索引的方法,或者batch_sampler屬性指定獲取下一個批次數據的索引。

當我們設置DataLoader的shuffle屬性爲True時,會根據batch_size屬性傳入的批次大小自動構造sample返回下一個批次的索引。

當我們不啓用shuffle屬性時,就可以通過batch_sampler屬性自定義sample來返回下一批的索引,注意這時候不可用使用 batch_size, shuffle, sampler, 和drop_last屬性。
如下所示爲自定義MySampler,它繼承自Sampler,由傳入dataset的長度產生對應的索引,例如上面有9個數據,那麼產生索引[0, 8]。根據批次大小batch_size計算出總批次數,例如當batchsize是3,那麼9/3=3,即總共有3個批次。重寫__iter__()方法按批次返回索引,即第一批返回[0, 1, 2],第二批返回[3, 4, 5]以此類推。__len__()方法返回總的批次數,即3個批次。

class MySampler(Sampler):
    def __init__(self, dataset, batchsize):
        super(Sampler, self).__init__()
        self.dataset = dataset
        self.batch_size = batchsize		# 每一批數據量
        self.indices = range(len(dataset))	# 生成數據集的索引
        self.count = int(len(dataset) / self.batch_size)	# 一共有多少批

    def __iter__(self):
        for i in range(self.count):
            yield self.indices[i * self.batch_size: (i + 1) * self.batch_size]

    def __len__(self):
        return self.count

collate處理數據

當我們拿到數據如果希望進行一些預處理而不是直接返回,這時候就需要collate_fn屬性來指定處理和返回數據的方法,如果不指定該屬性,默認會將普通的NumPy數組轉換爲PyTorch的tensor並直接返回。
如下所示爲自定義的my_collate()函數,默認傳入獲得的一個批次的數據data,例如之前返回一批數據[1, 2, 3],這裏遍歷數據並平方之後放在res數組中返回[1, 4, 9]

def my_collate(data):
    res = []
    for d in data:
        res.append(d ** 2)
    return res

有了上面的索引獲取類MySampler和數據處理函數my_collate(),就可以使用DataLoader自定義獲取批數據了。首先DataLoader通過my_sampler返回的索引[0, 1, 2]去dataset拿到數據[1, 2, 3],然後傳遞給my_collate進行平方操作,然後返回一個批次的結果爲[1, 4, 9],一共有三個批次的數據。

dataset = MySet()	# 定義數據集
my_sampler = MySampler(dataset, 3)		# 實例化MySampler

data_loader = DataLoader(dataset, batch_sampler=my_sampler, collate_fn=my_collate)

for data in data_loader:	# 按批次獲取數據
    print(data)
'''
[1, 4, 9]
[16, 25, 36]
[49, 64, 81]
'''
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章