数据流
Dataset
pytorch实现了一个基类Dataset来帮助构建数据集对象。要想实现自己的数据集类需要基于torch.utils.data.Dataset
来完成,并需要在类中实现两个方法,分别是:
__getitem__:用于返回一条数据,DATA_CLASS[index]相当于DATA_CLASS.__getitem__(index)
__len__:用于返回样本数量,len(DATA_CLASS)相当于DATA_CLASS.__len__()
ImageFolder
ImageFolder是torch在Dataset类基础上实现的一个常用的数据集对象。它假设所有文件都按文件夹保存,每个文件夹储存同一类别的图像,文件夹为类名。如果是分类任务,则可通过调用ImageFolder类来完成数据集加载,不需要再自己写了。
DataLoader
Dataset仅用于数据加载->返回样本,但实际训练时,还需要对数据进行shuffle和并行加速。DataLoader用于实现这些功能。
一些常用函数
- permute
permute用于对原数据进行维度调换,例如x.permute(1,0,2),将原本的第一维换成第0维 - squeeze和unsqueeze
squeeze用于去掉元素数量为1的维度,例如x.squeeze()
unsqueeze用于增加维度,例如b.unsqueeze(dim)