2 DataLoader-庖丁解牛之pytorch

數據集已經有了,直接使用不就得了,實際數據加載是一個很大的問題,涉及內存、cpu、GPU的利用關係,因此專門設計一個數據加載類DataLoader,我們先看一看這個類的參數

* dataset (Dataset): 裝載的數據集
* batch_size (int, optional): 每批加載批次大小,默認1
* shuffle (bool, optional): 每個epoch是否混淆
* sampler (Sampler, optional): 採樣器,與shuffle互斥
* batch_sampler (Sampler, optional): 和sampler類似,
* num_workers (int, optional): 多進程併發裝載,subprocess工作進程個數,默認0
* collate_fn (callable, optional): 合併mini-batch的採樣列表
* pin_memory (bool, optional): 鎖頁內存
* drop_last (bool, optional): 丟棄最後一個不完整的batch
* timeout (numeric, optional): 收集工作批次的等待時間    
* worker_init_fn (callable, optional): 每個工作進程根據worker ID調用

參數一大堆,但是函數就三個

__setattr__(self, attr, val)  設置屬性
__iter__(self)                    迭代
__len__(self)                     長度

採樣器

我們先看看採樣器
採樣器有如下幾個

Sampler 基本採樣器基類
SequentialSampler 序列採樣器 iter(range(len(self.data_source)))
RandomSampler 隨機採樣器iter(torch.randperm(len(self.data_source)).tolist())
SubsetRandomSampler 子集隨機採樣器 (self.indices[i] for i in torch.randperm(len(self.indices)))
WeightRandomSampler 權重隨機採樣器iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
BatchSampler 批處理採樣器
DistributedSampler 分佈採樣器
from torch.utils.data import BatchSampler, SequentialSampler
list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))

目前系統的採樣器只有這幾種,對於DataLoader來說批次也是採樣過程,因此都歸結爲採樣器。DataLoader的最重要一個函數是迭代器

迭代器

迭代器根據採樣器的處理,利用多線程技術,分批次進行加載,這也是DataLoader的核心,該進程首先申請兩類隊列,一類是索引隊列,一類是工作結果隊列,用於存儲進程之間的結果。之後引入最重要的工作進程_worker_loop這是一個全局函數,從索引隊列中領取任務,將結果放到工作結果隊列中,源碼如下:

......
    while True:
        try:
            # 領任務
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices]) # 幹活
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) # 撂挑子
        else:
            data_queue.put((idx, samples)) # 交結果
            del samples
......

工作管理進程收集上交結果

def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))

在工作管理進程收集結果的時候有個操作比較特別,pin_memory_batch稱鎖頁內存(pinned memory or page locked memory):創建DataLoader時,設置pin_memory=True,則意味着生成的Tensor數據最開始是屬於內存中的鎖頁內存,這樣將內存的Tensor轉義到GPU的顯存就會更快一些。
主機中的內存,有兩種存在方式,一是鎖頁,二是不鎖頁,鎖頁內存存放的內容在任何情況下都不會與主機的虛擬內存進行交換(注:虛擬內存就是硬盤),而不鎖頁內存在主機內存不足時,數據會存放在虛擬內存中。
而顯卡中的顯存全部是鎖頁內存。
當計算機的內存充足的時候,可以設置pin_memory=True。當系統卡住,或者交換內存使用過多的時候,設置pin_memory=False。因爲pin_memory與電腦硬件性能有關,pytorch開發者不能確保每一個煉丹玩家都有高端設備,因此pin_memory默認爲False。
數據裝載部分主要有迭代器來實現,此處代碼不清晰,主要過程就是多線程、內存管理、分批讀入等

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章