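Approach 1: Define an iterator class
A custom data iterator class mainly needs to define the following methods:
(1) __init__: the initializer, which can take a data path or the data itself;
(2) __iter__: returns the iterator object; in a custom data iterator class, simply return the object itself (self);
(3) __next__: returns the next item; it is called by next() and by for loops to read the data one item at a time, and it raises StopIteration once the data is exhausted;
(4) __len__: returns the length of the iterator (here, the number of batches). It is not required for iteration, but it lets len() work on the object.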
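To make the roles of __iter__ and __next__ concrete, the sketch below drives the iterator protocol by hand, roughly the way a for loop does internally. The function name manual_for is made up for this illustration.

```python
def manual_for(iterable):
    """Hypothetical helper: collect items by calling the protocol directly."""
    iterator = iter(iterable)      # calls iterable.__iter__()
    items = []
    while True:
        try:
            item = next(iterator)  # calls iterator.__next__()
        except StopIteration:      # raised when the data is exhausted
            break
        items.append(item)
    return items


print(manual_for([1, 2, 3]))  # [1, 2, 3]
```

Any class that implements __iter__ and __next__ as described above can be passed to this function in place of the list.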
Here is a simple example of such a class:
class DataLoader:
    """
    data: a sequence that supports len() and slicing
    batch_size: number of items per batch
    drop_last: whether to drop the last batch if it is incomplete
    """
    def __init__(self, data, batch_size, drop_last=False):
        self.data = data
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.index = 0
        # Compute the actual number of batches from the arguments
        if len(self.data) % self.batch_size != 0 and not self.drop_last:
            self.batch = len(self.data) // self.batch_size + 1
        else:
            self.batch = len(self.data) // self.batch_size

    # Return the object itself as the iterator
    def __iter__(self):
        self.index = 0  # restart from the first batch so the loader can be looped over again
        return self

    # Return the data for each iteration
    def __next__(self):
        if self.index < self.batch:
            data = self.data[self.index*self.batch_size:(self.index+1)*self.batch_size]
            self.index += 1
            return data
        else:
            raise StopIteration

    def __len__(self):
        return self.batch
Verify that the data iterator class above works:
data = list(range(10))
dataloader1 = DataLoader(data, 5)
for i in dataloader1:
    print(i)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
dataloader2 = DataLoader(data, 4)
for i in dataloader2:
    print(i)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]
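The batch counts above (2 batches for a batch size of 5, 3 for a batch size of 4) come from the arithmetic in __init__. As a rough sketch, the same calculation can be written as a standalone helper. The name num_batches is made up for this illustration and is not part of the original code.

```python
import math


def num_batches(n_items, batch_size, drop_last=False):
    """Hypothetical helper mirroring the batch-count logic in DataLoader.__init__."""
    if drop_last:
        # Only complete batches are kept
        return n_items // batch_size
    # An incomplete final batch still counts as one batch
    return math.ceil(n_items / batch_size)


print(num_batches(10, 5))                  # 2
print(num_batches(10, 4))                  # 3
print(num_batches(10, 4, drop_last=True))  # 2
```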
Approach 2: Use a generator function with yield
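def DataLoader(data, batch_size, drop_last=False):
    # Accumulate items until a full batch is ready, then yield it
    minibatch, size_sofar = [], 0
    for i in data:
        minibatch.append(i)
        size_sofar += 1
        if size_sofar == batch_size:
            yield minibatch
            minibatch, size_sofar = [], 0
    # Yield the trailing incomplete batch unless drop_last is set
    if not drop_last and minibatch:
        yield minibatch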
Verify that the function works:
data = list(range(10))
for i in DataLoader(data, batch_size=5):
    print(i)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
for i in DataLoader(data, batch_size=4, drop_last=True):
    print(i)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
for i in DataLoader(data, batch_size=4, drop_last=False):
    print(i)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]
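The generator version is shorter, but a generator object does not support len(). One possible way to combine the two approaches, offered only as a sketch, is a class whose __iter__ method is itself a generator. The class name BatchLoader is made up for this illustration, and the sketch assumes data supports len() and slicing.

```python
class BatchLoader:
    """Hypothetical class: generator-based __iter__ plus __len__."""

    def __init__(self, data, batch_size, drop_last=False):
        self.data = data
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # A generator method: every call returns a fresh iterator,
        # so the loader can be looped over more than once.
        for start in range(0, len(self.data), self.batch_size):
            batch = self.data[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                return  # skip the incomplete final batch
            yield batch

    def __len__(self):
        full, remainder = divmod(len(self.data), self.batch_size)
        return full if self.drop_last or remainder == 0 else full + 1


loader = BatchLoader(list(range(10)), batch_size=4)
print(len(loader))   # 3
print(list(loader))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(loader))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Because __iter__ creates a new generator on each call, no index attribute has to be stored or reset between passes.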