pytorch加載數據

參考:PyTorch深度學習快速入門教程(絕對通俗易懂!)【小土堆】
本文是上面視頻的筆記,up主講的特別詳細,推薦觀看。
在pytorch中加載數據主要涉及到兩個類:Dataset 和 Dataloader
Dataset :提供一種方式去提取數據並得到label
Dataset:對數據進行打包送到網絡中去,爲後面的網絡提供不同的數據形式。
下面是代碼及說明:

from torch.utils. data import Dataset

在這裏插入圖片描述
可看到說明,Dataset是一個抽象類,我們重寫Dataset時要繼承這個類,所有的子類都應該重寫__getitem__()方法,這個方法作用是獲取數據及對應的labe。同時我們可以選擇性地去重寫__len__方法,其作用是獲取數據集長度。

例子:

這裏我使用的是貓狗二分類的數據集,如圖:
在這裏插入圖片描述

from torch.utils. data import Dataset
from PIL import Image
import os

class Mydataset(Dataset):
    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir    ##根目錄
        self.label_dir = label_dir  ##標籤,也就是文件名
        self.path = os.path.join(self.root_dir,self.label_dir)  ##拼成一個完整的目錄
        self.img_path = os.listdir(self.path) ##獲得圖片的一個list

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  ##得到單個圖片的名字
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)  ##得到單個圖片的路徑
        img = Image.open(img_item_path) ##圖片數據
        label = self.label_dir ##標籤
        return img, label
    def __len__(self):
        return len(self.img_path)

root_dir="D:/貓狗大戰/data/train"
cat_label_dir = "cat"
dog_label_dir = "dog"
cat_dataset = Mydataset(root_dir,cat_label_dir)
dog_dataset = Mydataset(root_dir,dog_label_dir)
img, label = cat_dataset[1]
img.show()
print(label)

img, label = dog_dataset[1]
img.show()
print(label)

輸出結果:
cat
dog
在這裏插入圖片描述
寫給自己,另外,可以參考這篇博客:
https://ptorch.com/news/215.html
fastai也可以關注以下

發佈了61 篇原創文章 · 獲贊 19 · 訪問量 6111
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章