A note on a subtle mistake I made while building a custom DataLoader; I only found the root cause by reading the source code.
Scenario: build a custom DataLoader from a DataFrame.
Original data:
import pandas as pd
import random
df = pd.DataFrame({'feature':[[random.randint(0, 5) for _ in range(5)] for _ in range(10)], 'label':[random.randint(0, 5) for _ in range(10)]})
I naively wrote the following:
from torch.utils.data import Dataset, DataLoader

# Record each row as a (feature, label) tuple
def load_data(data):
    contents = []
    for i in range(len(data)):
        feature = data.iloc[i]['feature']
        label = data.iloc[i]['label']
        contents.append((feature, label))
    return contents

class Dataset1(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, item):
        return self.data[item]
    def __len__(self):
        return len(self.data)

result = load_data(df)
dataset1 = Dataset1(result)
dataloader1 = DataLoader(dataset1, batch_size=2)
Everything looks routine so far; now for the moment of truth:
for x, y in dataloader1:
    print(x)
    print(y)
"""[tensor([0, 2]), tensor([3, 5]), tensor([4, 3]), tensor([2, 5]), tensor([3, 1])]
tensor([5, 1])
[tensor([0, 3]), tensor([4, 0]), tensor([1, 3]), tensor([4, 2]), tensor([0, 1])]
tensor([2, 0])
[tensor([3, 4]), tensor([0, 1]), tensor([0, 2]), tensor([0, 2]), tensor([4, 1])]
tensor([5, 1])
[tensor([1, 1]), tensor([0, 5]), tensor([0, 0]), tensor([2, 1]), tensor([5, 5])]
tensor([1, 5])
[tensor([3, 3]), tensor([5, 3]), tensor([2, 0]), tensor([0, 4]), tensor([4, 4])]
tensor([2, 2])"""
??? Not at all what I expected. Let me struggle a bit more:
class Dataset2(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, item):
        return {'feature': self.data[item][0], 'label': self.data[item][1]},
    def __len__(self):
        return len(self.data)

dataset2 = Dataset2(result)
dataloader2 = DataLoader(dataset2, batch_size=2)
for x in dataloader2:
    print(x['feature'])
    print(x['label'])
This time it failed outright:
TypeError: list indices must be integers or slices, not str
Let's look at what the iteration actually yields:
for x in dataloader2:
    print(x)
'''
[{'feature': [tensor([0, 2]), tensor([3, 5]), tensor([4, 3]), tensor([2, 5]), tensor([3, 1])], 'label': tensor([5, 1])}]
[{'feature': [tensor([0, 3]), tensor([4, 0]), tensor([1, 3]), tensor([4, 2]), tensor([0, 1])], 'label': tensor([2, 0])}]
[{'feature': [tensor([3, 4]), tensor([0, 1]), tensor([0, 2]), tensor([0, 2]), tensor([4, 1])], 'label': tensor([5, 1])}]
[{'feature': [tensor([1, 1]), tensor([0, 5]), tensor([0, 0]), tensor([2, 1]), tensor([5, 5])], 'label': tensor([1, 5])}]
[{'feature': [tensor([3, 3]), tensor([5, 3]), tensor([2, 0]), tensor([0, 4]), tensor([4, 4])], 'label': tensor([2, 2])}]
'''
Still nothing like what I expected. Where exactly is the problem? Time to read the source. When DataLoader assembles each batch, it calls collate_fn to merge and convert the samples, and the default implementation, default_collate, looks like this:
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    raise TypeError((error_msg.format(type(batch[0]))))
Clearly, default_collate converts the data according to its type. In scheme one, each sample is a tuple, which is a container_abcs.Sequence, so the batch is transposed with zip and each resulting slot is collated recursively; because the feature itself is a list, it gets transposed again, yielding five tensors of length batch_size instead of a single (batch_size, 5) tensor. In scheme two, the value under each key is also a container_abcs.Sequence and receives the same treatment. (On top of that, the stray trailing comma in Dataset2.__getitem__ wraps every sample in a one-element tuple, which is why each batch comes back as a list containing a dict, and why x['feature'] raises the TypeError above.) Let's verify the zip behavior:
for i in zip(*([0, 3, 4, 2, 3], [2, 5, 3, 5, 1])):
    print(i)
'''
(0, 2)
(3, 5)
(4, 3)
(2, 5)
(3, 1)
'''
This matches exactly what scheme one produced.
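The recursion can be mimicked end to end in plain Python. Below is a simplified, torch-free sketch of default_collate's int, Mapping, and Sequence branches (the name toy_collate and the use of plain lists in place of tensors are my own stand-ins), applied to a batch of two (feature, label) samples like those from Dataset1:

```python
def toy_collate(batch):
    """Simplified stand-in for default_collate, with lists instead of tensors."""
    elem = batch[0]
    if isinstance(elem, (int, float)):
        return list(batch)                 # stand-in for torch.LongTensor(batch)
    if isinstance(elem, dict):
        return {key: toy_collate([d[key] for d in batch]) for key in elem}
    if isinstance(elem, (list, tuple)):
        transposed = zip(*batch)           # the culprit: per-element transpose
        return [toy_collate(samples) for samples in transposed]
    raise TypeError(type(elem))

# A batch of two (feature, label) samples, as produced by Dataset1
batch = [([0, 3, 4, 2, 3], 5), ([2, 5, 3, 5, 1], 1)]
print(toy_collate(batch))
# → [[[0, 2], [3, 5], [4, 3], [2, 5], [3, 1]], [5, 1]]
```

The feature lists are transposed twice (once as tuples, once as inner lists), reproducing the five length-2 "columns" seen in the first experiment.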
So wrapping each field of a sample in container_abcs.Sequence data is risky: the elements along each column/feature are treated as heterogeneous and collated one by one. How can we avoid the problem? A reliable fix is to convert the Sequence-typed feature to an np.array, which routes it through the numpy branch and sidesteps the zip.
import numpy as np

class Dataset3(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, item):
        return np.array(self.data[item][0]), self.data[item][1]
    def __len__(self):
        return len(self.data)

dataset3 = Dataset3(result)
dataloader3 = DataLoader(dataset3, batch_size=2)
for x, y in dataloader3:
    print(x)
    print(y)
"""
tensor([[0, 3, 4, 2, 3],
[2, 5, 3, 5, 1]])
tensor([5, 1])
tensor([[0, 4, 1, 4, 0],
[3, 0, 3, 2, 1]])
tensor([2, 0])
tensor([[3, 0, 0, 0, 4],
[4, 1, 2, 2, 1]])
tensor([5, 1])
tensor([[1, 0, 0, 2, 5],
[1, 5, 0, 1, 5]])
tensor([1, 5])
tensor([[3, 5, 2, 0, 4],
[3, 3, 0, 4, 4]])
tensor([2, 2])
"""
Everything is finally back to normal!
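Alternatively (a variant not covered above), you can keep Dataset1 unchanged and hand DataLoader a custom collate_fn that stacks the list-valued features explicitly; list_collate below is a hypothetical name of my own:

```python
import torch

# Hypothetical custom collate_fn: build the (batch, 5) feature tensor
# ourselves instead of letting default_collate zip the feature lists.
def list_collate(batch):
    features = torch.tensor([feature for feature, _ in batch])  # shape (batch, 5)
    labels = torch.tensor([label for _, label in batch])        # shape (batch,)
    return features, labels

# Usage: DataLoader(dataset1, batch_size=2, collate_fn=list_collate)
```

Because collate_fn replaces default_collate entirely, the Sequence branch is never reached and the transposition problem disappears.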
Lesson: when something behaves oddly, there is always a cause. Even a simple API can hide plenty of tricks. When a problem shows up, pinpoint it and read the source to understand how the mechanism works; writing a similar API yourself is a good way to deepen your grasp of the underlying logic.