參考:PyTorch深度學習快速入門教程(絕對通俗易懂!)【小土堆】
本文是上面視頻的筆記,up主講的特別詳細,推薦觀看。
在pytorch中加載數據主要涉及到兩個類:Dataset 和 Dataloader
Dataset :提供一種方式去提取數據並得到label
Dataset:對數據進行打包送到網絡中去,爲後面的網絡提供不同的數據形式。
下面是代碼及說明:
from torch.utils. data import Dataset
可看到說明,Dataset是一個抽象類,我們重寫Dataset時要繼承這個類,所有的子類都應該重寫__getitem__()方法,這個方法作用是獲取數據及對應的labe。同時我們可以選擇性地去重寫__len__方法,其作用是獲取數據集長度。
例子:
這裏我使用的是貓狗二分類的數據集,如圖:
from torch.utils. data import Dataset
from PIL import Image
import os
class Mydataset(Dataset):
def __init__(self,root_dir, label_dir):
self.root_dir = root_dir ##根目錄
self.label_dir = label_dir ##標籤,也就是文件名
self.path = os.path.join(self.root_dir,self.label_dir) ##拼成一個完整的目錄
self.img_path = os.listdir(self.path) ##獲得圖片的一個list
def __getitem__(self, idx):
img_name = self.img_path[idx] ##得到單個圖片的名字
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) ##得到單個圖片的路徑
img = Image.open(img_item_path) ##圖片數據
label = self.label_dir ##標籤
return img, label
def __len__(self):
return len(self.img_path)
root_dir="D:/貓狗大戰/data/train"
cat_label_dir = "cat"
dog_label_dir = "dog"
cat_dataset = Mydataset(root_dir,cat_label_dir)
dog_dataset = Mydataset(root_dir,dog_label_dir)
img, label = cat_dataset[1]
img.show()
print(label)
img, label = dog_dataset[1]
img.show()
print(label)
輸出結果:
cat
dog
寫給自己,另外,可以參考這篇博客:
https://ptorch.com/news/215.html
fastai也可以關注以下