Pytorch基礎:數據加載和預處理

Pytorch基礎:數據加載和預處理

Pytorch通過torch.utils.data對數據實現封裝,可以容易的實現多線程數據預讀和批量加載

import torch
torch.__version__
'1.1.0'

Dataset

Dataset是一個抽象類,爲方便讀取,需要將使用的數據包裝爲Dataset類。自定義Dataset需要繼承它並實現他的兩個方法:

  1. getitem() 該方法定義用索引(0到self.len)獲取一條數據或一個樣本
  2. len() 該方法返回數據總長度
from torch.utils.data import Dataset
import numpy as np


# 定義一個數據類
class Diabetes(Dataset):
    def __init__(self):
        super(Diabetes, self).__init__()
        data = np.loadtxt('.//data//diabetes.csv.gz',
                          delimiter=',',
                          dtype=np.float32)
        self.len = data.shape[0]
        self.x_data = torch.from_numpy(data[:, 0:-1])
        self.y_data = torch.from_numpy(data[:, [-1]])

    # 根據index返回一行數據
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # 返回data長度
        return self.len

len 方法可以直接使用len獲取數據總數

diabetes = Diabetes()
len(diabetes)
759

DataLoader

DataLoader提供了對Dataset的讀取操作,常用的參數:batch_size(每個批次大小),shuffle(是否進行shuffle操作),num_workers(加載數據時使用幾個子進程)

d = torch.utils.data.DataLoader(diabetes,
                                batch_size=10,
                                shuffle=True,
                                num_workers=0)

DataLoader返回一個可迭代對象,可使用迭代器分批次獲取

itdata = iter(d)
next(itdata)
[tensor([[ 0.7647,  0.3668,  0.1475, -0.3535, -0.7400,  0.1058, -0.9360, -0.2667],
         [-0.8824,  0.0050,  0.0820, -0.6970, -0.8676, -0.2966, -0.4979, -0.8333],
         [-0.5294,  0.3668,  0.1475,  0.0000,  0.0000, -0.0700, -0.0572, -0.9667],
         [ 0.0000,  0.4171,  0.0000,  0.0000,  0.0000,  0.2638, -0.8915, -0.7333],
         [-0.7647,  0.0854,  0.0164, -0.3535, -0.8676, -0.2489, -0.9573,  0.0000],
         [-0.5294,  0.1256,  0.2787, -0.1919,  0.0000,  0.1744, -0.8651, -0.4333],
         [-0.7647, -0.1859, -0.0164, -0.5556,  0.0000, -0.1744, -0.8190, -0.8667],
         [-0.7647,  0.1859,  0.3115,  0.0000,  0.0000,  0.2787, -0.4748,  0.0000],
         [-0.8824,  0.1256,  0.3115, -0.0909, -0.6879,  0.0373, -0.8813, -0.9000],
         [ 0.0000,  0.4673,  0.3443,  0.0000,  0.0000,  0.2072,  0.4543, -0.2333]]),
 tensor([[0.],
         [1.],
         [0.],
         [0.],
         [1.],
         [1.],
         [1.],
         [0.],
         [1.],
         [1.]])]
# 常見用法是使用for循環遍歷
for i, data in enumerate(d):
    print(i, data)
    break
0 [tensor([[ 0.0000,  0.1859,  0.3770, -0.0505, -0.4563,  0.3651, -0.5961, -0.6667],
        [ 0.0588, -0.1055,  0.0164,  0.0000,  0.0000, -0.3294, -0.9453, -0.6000],
        [ 0.0000,  0.1759,  0.0820, -0.3737, -0.5556, -0.0820, -0.6456, -0.9667],
        [ 0.1765,  0.6884,  0.2131,  0.0000,  0.0000,  0.1326, -0.6080, -0.5667],
        [-0.7647,  0.4673,  0.0000,  0.0000,  0.0000, -0.1803, -0.8617, -0.7667],
        [-0.1765,  0.1457,  0.2459, -0.6566, -0.7400, -0.2906, -0.6687, -0.6667],
        [-0.7647,  0.1256,  0.0820, -0.5556,  0.0000, -0.2548, -0.8044, -0.9000],
        [-0.8824,  0.3367,  0.6721, -0.4343, -0.6690, -0.0224, -0.8668, -0.2000],
        [-0.8824,  0.6784,  0.2131, -0.6566, -0.6596, -0.3025, -0.6849, -0.6000],
        [-0.8824,  0.1658,  0.2787, -0.4141, -0.5745,  0.0760, -0.6430, -0.8667]]), tensor([[0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [1.]])]

torchvision

torchvision是Pytorch中用來處理圖像的庫
torchvision.datasets 爲Pytorch官方定義的dataset:可直接使用MNIST、COCO、Detetion、LSUN、CIFAR10等

from torchvision import datasets, transforms
trainset = datasets.MNIST(
    root='.//data//',  # 加載MNIST數據的目錄
    train=True,  # 標識加載數據集,爲false時爲測試集
    download=True,  # 是否自動下載數據
    transform=True)  # 是否需要對數據進行預處理, None時不進行預處理

torchvision.models

torchvision還提供了訓練好的模型,可以在進行遷移學習torchvision.models模塊的子模塊中包含以下結構:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
from torchvision import models
resnet18 = models.resnet18(pretrained=True)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\Zephyrus/.cache\torch\checkpoints\resnet18-5c106cde.pth



---------------------------------------------------------------------------

torchvision.transforms

transforms模塊提供了一般的圖像轉換操作類,用於數據處理和數據增強

from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 先四周填充0,把圖像隨機裁剪成32x32
    transforms.RandomHorizontalFlip(),  # 把圖像一般概率翻轉,一半的概率不翻轉
    transforms.RandomRotation((-45, 45)),  # 隨機旋轉
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.229, 0.224, 0.225))  # RGB每層的歸一化用到的均值和方差
])

關於(0.4914, 0.4822, 0.4465),(0.229, 0.224, 0.225)詳情說明,這些是根據ImageNet訓練的歸一化參數,可以直接使用,可認爲爲固定值


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章