數據集加載,Dataset、ImageFloader

最近寫了一個簡單的程序,利用自己的數據集和resnet50 ,得到最後的特徵。其中用到了數據加載,resnet50的利用,其中resnet50的網絡好寫,但是數據集加載去網上查了才寫了出來,特此留個記錄。

參考鏈接: https://blog.csdn.net/xuan_liu123/article/details/101145366k.

1 RESNET50搭建

'''
完成網絡搭建
'''
import torch
import torch.nn as nn
from torchvision.models import resnet50

resnet = resnet50(pretrain = True)
'''
###可以通過這種方式達到修改步長的方式,從而得到特徵不一樣的地方。
resnet.layer4[0].downsample[0].stride = (1,1)
resnet.layer4[0].conv2.stride = (1,1)
'''
module = list(resnet.children())[:-2]# 去除了平均池化層和FC層
self.backbone = nn.sequential(*module)

2 數據集的操作

參考鏈接: https://blog.csdn.net/xuan_liu123/article/details/101145366k. 寫的特別好。

在網上查了好多資料,基本上分爲兩種,一種是Dataset,一種是ImaFloader

在學習Pytorch的教程時,加載數據許多時候都是直接調用 torchvision.datasets 裏面集成的數據集,直接在線下載,然後使用torch.utils.data.DataLoader進行加載。
那麼,我們怎麼使用我們自己的數據集,然後用DataLoader進行加載呢?

常見的兩種形式的導入:

一種是整個數據集都在一個文件下,內部再另附一個label文件,說明每個文件的狀態。這種存放數據的方式可能更時候在非分類問題上得到應用。
一種則是更適合在分類問題上,即把不同種類的數據分爲不同的文件夾存放起來。這樣,我們可以從文件夾或文件名得到label。
我們以貓狗數據集爲例,進行自定義加載數據。
貓狗數據集裏面有兩個文件夾,分別是test和train。
其中train文件夾下的圖片,命名方式爲:cat.0.jpg或dog.0.jpg。我們可以從文件名中提取出來作爲我們的圖片標籤。

在這裏插入圖片描述
在這裏插入圖片描述

2.1 Dataset class

這種方法是官方導航介紹的。
torch.utils.data.Dataset 是一個抽象類,用戶想要加載自定義的數據只需要繼承這個類,並且覆寫其中的兩個方法即可:

  • _ len _:實現len(dataset)返回整個數據集的大小。
  • __ getitem__ 用來獲取一些索引的數據,使dataset[i]返回數據集中第i個樣本。
  • 不覆寫這兩個方法會直接返回錯誤。
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
        

建立的自定義類如下:

#導入相關模塊
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np

class AnimalData(Dataset): #繼承Dataset
    def __init__(self, root_dir, transform=None): #__init__是初始化該類的一些基礎參數
        self.root_dir = root_dir   #文件目錄
        self.transform = transform #變換
        self.images = os.listdir(self.root_dir)#目錄裏的所有文件
    
    def __len__(self):#返回整個數據集的大小
        return len(self.images)
    
    def __getitem__(self,index):#根據索引index返回dataset[index]
        image_index = self.images[index]#根據索引index獲取該圖片
        img_path = os.path.join(self.root_dir, image_index)#獲取索引爲index的圖片的路徑名
        img = io.imread(img_path)# 讀取該圖片
        label = img_path.split('\\')[-1].split('.')[0]# 根據該圖片的路徑名獲取該圖片的label,具體根據路徑名進行分割。我這裏是"E:\\Python Project\\Pytorch\\dogs-vs-cats\\train\\cat.0.jpg",所以先用"\\"分割,選取最後一個爲['cat.0.jpg'],然後使用"."分割,選取[cat]作爲該圖片的標籤
        sample = {'image':img,'label':label}#根據圖片和標籤創建字典
        
        if self.transform:
            sample = self.transform(sample)#對樣本進行變換
        return sample #返回該樣本


設置好數據類之後,我們就可以將其用torch.utils.data.DataLoader加載,並訪問它。

if __name__=='__main__':
    data = AnimalData('E:/Python Project/PyTorch/dogs-vs-cats/train',transform=None)#初始化類,設置數據集所在路徑以及變換
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加載數據
    for i_batch,batch_data in enumerate(dataloader):
        print(i_batch)#打印batch編號
        print(batch_data['image'].size())#打印該batch裏面圖片的大小
        print(batch_data['label'])#打印該batch裏面圖片的標籤

輸出如下:

0
torch.Size([128, 3, 224, 224])
['dog', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'cat', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat']

2.2 torchvision

pytorch幾乎將上述所有工作都封裝起來供我們使用,其中一個工具就是torchvision.datasets.ImageFolder,用於加載用戶自定義的數據,要求我們的數據要有如下結構(每一類都是按照文件夾分好的):
將圖片按類別分文件夾存放。

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

torchvision.transforms中也封裝了各種各樣的數據處理的工具,如Resize, ToTensor等等功能供我們使用。

from torchvision import transforms,utils
from torchvision import  datasets
import torch
import matplotlib.pyplot as plt
import torch.utils.data

train_data = datasets.ImageFolder(r'E:\Python Project\PyTorch\data\hotdog\train',transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
]))
print(train_data.classes)#獲取標籤
train_loader = torch.utils.data.DataLoader(train_data,batch_size=4,shuffle=True)

print(len(train_loader))
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])   #標籤轉化爲編碼
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()#顯示操作
    break


#######輸出結果#####
classes=   ['hotdog', 'not-hotdog']
class_to_idx=  {'hotdog': 0, 'not-hotdog': 1}
500
img[1]= tensor([0, 1, 0, 1])


在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章