Pytorch基础:数据加载和预处理

Pytorch基础:数据加载和预处理

Pytorch通过torch.utils.data对数据实现封装,可以容易的实现多线程数据预读和批量加载

import torch
torch.__version__
'1.1.0'

Dataset

Dataset是一个抽象类,为方便读取,需要将使用的数据包装为Dataset类。自定义Dataset需要继承它并实现他的两个方法:

  1. getitem() 该方法定义用索引(0到self.len)获取一条数据或一个样本
  2. len() 该方法返回数据总长度
from torch.utils.data import Dataset
import numpy as np


# 定义一个数据类
class Diabetes(Dataset):
    def __init__(self):
        super(Diabetes, self).__init__()
        data = np.loadtxt('.//data//diabetes.csv.gz',
                          delimiter=',',
                          dtype=np.float32)
        self.len = data.shape[0]
        self.x_data = torch.from_numpy(data[:, 0:-1])
        self.y_data = torch.from_numpy(data[:, [-1]])

    # 根据index返回一行数据
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # 返回data长度
        return self.len

len 方法可以直接使用len获取数据总数

diabetes = Diabetes()
len(diabetes)
759

DataLoader

DataLoader提供了对Dataset的读取操作,常用的参数:batch_size(每个批次大小),shuffle(是否进行shuffle操作),num_workers(加载数据时使用几个子进程)

d = torch.utils.data.DataLoader(diabetes,
                                batch_size=10,
                                shuffle=True,
                                num_workers=0)

DataLoader返回一个可迭代对象,可使用迭代器分批次获取

itdata = iter(d)
next(itdata)
[tensor([[ 0.7647,  0.3668,  0.1475, -0.3535, -0.7400,  0.1058, -0.9360, -0.2667],
         [-0.8824,  0.0050,  0.0820, -0.6970, -0.8676, -0.2966, -0.4979, -0.8333],
         [-0.5294,  0.3668,  0.1475,  0.0000,  0.0000, -0.0700, -0.0572, -0.9667],
         [ 0.0000,  0.4171,  0.0000,  0.0000,  0.0000,  0.2638, -0.8915, -0.7333],
         [-0.7647,  0.0854,  0.0164, -0.3535, -0.8676, -0.2489, -0.9573,  0.0000],
         [-0.5294,  0.1256,  0.2787, -0.1919,  0.0000,  0.1744, -0.8651, -0.4333],
         [-0.7647, -0.1859, -0.0164, -0.5556,  0.0000, -0.1744, -0.8190, -0.8667],
         [-0.7647,  0.1859,  0.3115,  0.0000,  0.0000,  0.2787, -0.4748,  0.0000],
         [-0.8824,  0.1256,  0.3115, -0.0909, -0.6879,  0.0373, -0.8813, -0.9000],
         [ 0.0000,  0.4673,  0.3443,  0.0000,  0.0000,  0.2072,  0.4543, -0.2333]]),
 tensor([[0.],
         [1.],
         [0.],
         [0.],
         [1.],
         [1.],
         [1.],
         [0.],
         [1.],
         [1.]])]
# 常见用法是使用for循环遍历
for i, data in enumerate(d):
    print(i, data)
    break
0 [tensor([[ 0.0000,  0.1859,  0.3770, -0.0505, -0.4563,  0.3651, -0.5961, -0.6667],
        [ 0.0588, -0.1055,  0.0164,  0.0000,  0.0000, -0.3294, -0.9453, -0.6000],
        [ 0.0000,  0.1759,  0.0820, -0.3737, -0.5556, -0.0820, -0.6456, -0.9667],
        [ 0.1765,  0.6884,  0.2131,  0.0000,  0.0000,  0.1326, -0.6080, -0.5667],
        [-0.7647,  0.4673,  0.0000,  0.0000,  0.0000, -0.1803, -0.8617, -0.7667],
        [-0.1765,  0.1457,  0.2459, -0.6566, -0.7400, -0.2906, -0.6687, -0.6667],
        [-0.7647,  0.1256,  0.0820, -0.5556,  0.0000, -0.2548, -0.8044, -0.9000],
        [-0.8824,  0.3367,  0.6721, -0.4343, -0.6690, -0.0224, -0.8668, -0.2000],
        [-0.8824,  0.6784,  0.2131, -0.6566, -0.6596, -0.3025, -0.6849, -0.6000],
        [-0.8824,  0.1658,  0.2787, -0.4141, -0.5745,  0.0760, -0.6430, -0.8667]]), tensor([[0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.]])]

torchvision

torchvision是Pytorch中用来处理图像的库
torchvision.datasets 为Pytorch官方定义的dataset:可直接使用MNIST、COCO、Detetion、LSUN、CIFAR10等

from torchvision import datasets, transforms
trainset = datasets.MNIST(
    root='.//data//',  # 加载MNIST数据的目录
    train=True,  # 标识加载数据集,为false时为测试集
    download=True,  # 是否自动下载数据
    transform=True)  # 是否需要对数据进行预处理, None时不进行预处理

torchvision.models

torchvision还提供了训练好的模型,可以在进行迁移学习torchvision.models模块的子模块中包含以下结构:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
from torchvision import models
resnet18 = models.resnet18(pretrained=True)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\Zephyrus/.cache\torch\checkpoints\resnet18-5c106cde.pth



---------------------------------------------------------------------------

torchvision.transforms

transforms模块提供了一般的图像转换操作类,用于数据处理和数据增强

from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 先四周填充0,把图像随机裁剪成32x32
    transforms.RandomHorizontalFlip(),  # 把图像一般概率翻转,一半的概率不翻转
    transforms.RandomRotation((-45, 45)),  # 随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.229, 0.224, 0.225))  # RGB每层的归一化用到的均值和方差
])

关于(0.4914, 0.4822, 0.4465),(0.229, 0.224, 0.225)详情说明,这些是根据ImageNet训练的归一化参数,可以直接使用,可认为为固定值


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章