torch.utils.data.DataLoader是pytorch提供的數據加載類,初始化函數如下,
torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
dataset,batch_size等參數重要且容易理解,而collate_fn參數就不太直白,官方解釋爲:
collate_fn (callable, optional) – merges a list of samples to form a mini-batch
不明不白。
其實,collate_fn可理解爲函數句柄、指針...或者其他可調用類(實現__call__函數)。 函數輸入爲list,list中的元素爲欲取出的一系列樣本。具體如下
indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])
其中self.sampler_iter即採樣器,返回下一個batch中樣本的序號,indices。
通過collate_fn函數可以對這些樣本做進一步的處理(任何你想要的處理),原則上返回值應當是一個有結構的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。