PyTorch | Fixing identical NumPy random numbers across multi-process DataLoader workers

Problem description

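  • PyTorch's DataLoader is used to load data. With num_workers > 1, numpy.random returns the same
    random numbers in every worker process, i.e. all workers share the same seed. The random and torch modules do not have this problem. Signature: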
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
  • Code to reproduce the problem
from torch.utils.data import Dataset, DataLoader
import numpy as np

class NthreadDateset(Dataset):
    """Toy dataset: returns an index value plus one number drawn from NumPy's global RNG."""
    def __init__(self):
        self.datas = np.arange(100)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, index):
        data = self.datas[index]
        # Drawn inside the worker process, so it reveals whether workers share RNG state
        random_data = np.random.uniform(0.0, 1.0)
        return data, random_data

if __name__ == '__main__':
    datasets = NthreadDateset()
    # 4 worker processes; batch_size defaults to 1
    data_loader = DataLoader(datasets, num_workers=4, shuffle=True)
    for i, data in enumerate(data_loader):
        print(data)

Output (excerpt)

[tensor([5]), tensor([0.8464], dtype=torch.float64)]
[tensor([92]), tensor([0.8464], dtype=torch.float64)]
[tensor([46]), tensor([0.8464], dtype=torch.float64)]
[tensor([44]), tensor([0.8464], dtype=torch.float64)]
[tensor([69]), tensor([0.9780], dtype=torch.float64)]
[tensor([60]), tensor([0.9780], dtype=torch.float64)]
[tensor([53]), tensor([0.9780], dtype=torch.float64)]
[tensor([12]), tensor([0.9780], dtype=torch.float64)]
[tensor([0]), tensor([0.6385], dtype=torch.float64)]
[tensor([33]), tensor([0.6385], dtype=torch.float64)]
[tensor([85]), tensor([0.6385], dtype=torch.float64)]
[tensor([96]), tensor([0.6385], dtype=torch.float64)]

As the output shows, the four workers produce identical random numbers: 0.8464, 0.9780, …
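The repeats come from how the workers start. When worker processes are created with fork, which has long been the default start method on Linux, each child receives a copy of the parent's memory. That copy includes NumPy's global RNG state, and nothing reseeds it. PyTorch assigns each worker its own seed for its own RNG, but according to its documentation, seeds for other libraries such as NumPy may be duplicated.

The sketch below reproduces the effect without PyTorch. It uses only `multiprocessing` with an explicit fork context, so it runs on POSIX systems only. `draw` and `draws_from_forked_children` are illustrative names, not part of any library.

```python
import multiprocessing as mp
import numpy as np

def draw(queue):
    # Runs in the child: draws from NumPy's global RNG without reseeding it
    queue.put(np.random.uniform(0.0, 1.0))

def draws_from_forked_children(n_children=4):
    ctx = mp.get_context("fork")  # "fork" is only available on POSIX systems
    queue = ctx.Queue()
    procs = [ctx.Process(target=draw, args=(queue,)) for _ in range(n_children)]
    for p in procs:
        p.start()
    values = [queue.get() for _ in procs]  # read results before join to avoid blocking
    for p in procs:
        p.join()
    return values

if __name__ == "__main__":
    # Every child inherited the same RNG state, so all values are equal
    print(draws_from_forked_children())
```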

Solution

  • Official recommendation: pass a worker_init_fn that seeds each worker differently. From the PyTorch documentation:
By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by main process using its RNG (thereby, consuming a RNG state mandatorily). However, seeds for other libraries may be duplicated upon initializing workers (e.g., NumPy), causing each worker to return identical random numbers. (See the corresponding section in the FAQ.)

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
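The documentation's suggestion can be sketched as a small worker_init_fn. This is a minimal version, and the function name `worker_init_fn_torch_seed` is hypothetical. It reuses the per-worker PyTorch seed instead of a hand-picked constant. `np.random.seed` only accepts values in [0, 2**32 - 1], while `torch.initial_seed()` can return a 64-bit value, so the seed is reduced modulo 2**32.

```python
import numpy as np
import torch

def worker_init_fn_torch_seed(worker_id):
    # Inside a DataLoader worker, torch.initial_seed() is base_seed + worker_id,
    # so it already differs between workers (worker_id itself is not needed here).
    # np.random.seed() rejects values >= 2**32, so reduce the 64-bit seed first.
    np.random.seed(torch.initial_seed() % 2**32)
```

Pass it as `worker_init_fn=worker_init_fn_torch_seed` when constructing the DataLoader. Because the main process draws a new base_seed each time a new iterator starts its workers, the NumPy seeds also change from epoch to epoch.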
  • Implementation
    • Define a worker_init_fn function so that numpy.random gets a different seed in each worker.
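def worker_init_fn_seed(worker_id):
    # Give each worker its own NumPy seed: 10 for worker 0, 11 for worker 1, ...
    seed = 10 + worker_id
    np.random.seed(seed)
    print(worker_id)  # shows which worker has just been initialized

if __name__ == '__main__':
    datasets = NthreadDateset()
    data_loader = DataLoader(datasets,
                             num_workers=4,
                             shuffle=True,
                             # called once in each worker process at startup
                             worker_init_fn=worker_init_fn_seed)
    for i, data in enumerate(data_loader):
        print(data)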
  • Result (excerpt): the random numbers now differ across workers
[tensor([94]), tensor([0.7713], dtype=torch.float64)]
[tensor([87]), tensor([0.1803], dtype=torch.float64)]
[tensor([47]), tensor([0.1542], dtype=torch.float64)]
[tensor([90]), tensor([0.7777], dtype=torch.float64)]
[tensor([66]), tensor([0.0208], dtype=torch.float64)]
[tensor([65]), tensor([0.0195], dtype=torch.float64)]
[tensor([97]), tensor([0.7400], dtype=torch.float64)]
[tensor([57]), tensor([0.2375], dtype=torch.float64)]
[tensor([59]), tensor([0.6336], dtype=torch.float64)]
[tensor([21]), tensor([0.4632], dtype=torch.float64)]
[tensor([7]), tensor([0.2633], dtype=torch.float64)]
[tensor([35]), tensor([0.8243], dtype=torch.float64)]
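There is one caveat with the fixed base seed of 10 used above. Unless `persistent_workers=True` is set, the DataLoader starts fresh worker processes for every epoch and calls worker_init_fn again with the same worker_id. As a result, each worker replays the same NumPy sequence every epoch.

The sketch below simulates two epochs of worker 0 without a DataLoader. It restates the post's `worker_init_fn_seed` (without the print) so that the example is self-contained. `first_draws` is a hypothetical helper.

```python
import numpy as np

def worker_init_fn_seed(worker_id):
    # Same scheme as the post: constant base seed 10 plus worker_id
    np.random.seed(10 + worker_id)

def first_draws(worker_id, n=3):
    """Simulate a freshly started worker: run the init function, then draw n numbers."""
    worker_init_fn_seed(worker_id)
    return [np.random.uniform(0.0, 1.0) for _ in range(n)]

epoch1 = first_draws(0)
epoch2 = first_draws(0)
print(epoch1 == epoch2)  # True: worker 0 gets the same numbers in both "epochs"
```

With shuffle=True the repeated numbers land on different samples each epoch, but the sequence itself does not change. Seeding from `torch.initial_seed()`, as the documentation suggests, avoids this repetition.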