Pytorch學習(一)————定義自己的數據集

定義自己的數據集

Dataset類

torch.utils.data.Dataset 是數據集的抽象類,當我們定義自己的數據集都要繼承這個方法,並且必須覆蓋它的__len__和__getitem__這個兩個方法,__len__提高了數據集的大小,__getitem__用來索引數據集中每個樣本,

如何讀取圖像數據集,這裏不是直接將圖像放入內存,而是獲得圖像地址就可以了
具體例子:
下面展示一些 內聯代碼片

import torch.utils.data
import os
from PTL import Image
from torchvision import transforms
# 圖像預處理
#########
data_tansforms=transforms.Compose([
		transforms.ToTensor(),
		transforms.Resize((256,128),interpolation=3)
	])
# 自定義一個數據集類
#######
class Kesci_set(data.Dataset):
	//構造函數
	def__init__(self,root,transforms=None):
		//圖像路徑前面根目錄下所有圖像名稱list
		imgs=os.listdir(root) 
		//imgs=圖像根目錄+圖像名(每張圖像絕對路徑list)
		self.imgs=[os.path.join(root,k) for k in imgs]
		//圖像預處理包括(張量化,尺寸固定,旋轉,隨機裁剪)
		self.transforms=transforms
	def__getitem__(self,index):
		img_path=self.imgs[index]
		pil_img=Image.open(img_path)
		if self.transforms:
			data=self.transforms(pil_img)
		else:
			//圖像先變成python的array數組
			pil_img=np.asarray(pil_img)
			//然後將python array數組轉換爲torch的張量
			data=torch.from_numpy(pil_img)
	def __len__(self):
		return len(self.imgs)
#定義一個加載器
######
root="d:/imgdata"
image_datasets=Kesci_set(root,data_transforms)
dataloaders=data.DataLoader(image_datasets,batch_size=16,num_workers=0,shuffle=False)
#使用
#############
for data in dataloaders:
	img=data
	n,c,h,w=img.size() 
	.....

注意二點:
1、處理圖像時,必須要把讀取的圖像轉化我pytorch tensor張量形式。所以如果你添加了transforms的話一定要有transforms.ToTensor()這一句。
2、如果是用windows系統跑,設num_worker=0,windows多線程讀取數據可能會異常報錯。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章