torch学习记录

数据流

Dataset

pytorch实现了一个基类Dataset来帮助构建数据集对象。要想实现自己的数据集类需要基于torch.utils.data.Dataset来完成,并需要在类中实现两个方法,分别是:

__getitem__:用于返回一条数据,DATA_CLASS[index]相当于DATA_CLASS.__getitem__(index)
__len__:用于返回样本数量,len(DATA_CLASS)相当于DATA_CLASS.__len__()

ImageFolder
ImageFolder是torch在Dataset类基础上实现的一个常用的数据集对象。它假设所有文件都按文件夹保存,每个文件夹储存同一类别的图像,文件夹为类名。如果是分类任务,则可通过调用ImageFolder类来完成数据集加载,不需要再自己写了。

DataLoader

Dataset仅用于数据加载->返回样本,但实际训练时,还需要对数据进行shuffle和并行加速。DataLoader用于实现这些功能。

一些常用函数

  1. permute
    permute用于对原数据进行维度调换,例如x.permute(1,0,2),将原本的第一维换成第0维
  2. squeeze和unsqueeze
    squeeze用于去掉元素数量为1的维度,例如x.squeeze()
    unsqueeze用于增加维度,例如b.unsqueeze(dim)
发布了18 篇原创文章 · 获赞 16 · 访问量 2万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章