torch.utils.data.dataloader參數collate_fn簡析

torch.utils.data.DataLoader是pytorch提供的數據加載類,初始化函數如下,

torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

dataset,batch_size等參數重要且容易理解,而collate_fn參數就不太直白,官方解釋爲:

collate_fn (callableoptional) – merges a list of samples to form a mini-batch

不明不白。

其實,collate_fn可理解爲函數句柄、指針...或者其他可調用類(實現__call__函數)。 函數輸入爲list,list中的元素爲欲取出的一系列樣本。具體如下

indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])

其中self.sampler_iter即採樣器,返回下一個batch中樣本的序號,indices。

通過collate_fn函數可以對這些樣本做進一步的處理(任何你想要的處理),原則上返回值應當是一個有結構的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章