定義自己的數據集
Dataset類
torch.utils.data.Dataset 是數據集的抽象類,當我們定義自己的數據集都要繼承這個方法,並且必須覆蓋它的__len__和__getitem__這個兩個方法,__len__提高了數據集的大小,__getitem__用來索引數據集中每個樣本,
如何讀取圖像數據集,這裏不是直接將圖像放入內存,而是獲得圖像地址就可以了
具體例子:
下面展示一些 內聯代碼片
。
import torch.utils.data
import os
from PTL import Image
from torchvision import transforms
# 圖像預處理
#########
data_tansforms=transforms.Compose([
transforms.ToTensor(),
transforms.Resize((256,128),interpolation=3)
])
# 自定義一個數據集類
#######
class Kesci_set(data.Dataset):
//構造函數
def__init__(self,root,transforms=None):
//圖像路徑前面根目錄下所有圖像名稱list
imgs=os.listdir(root)
//imgs=圖像根目錄+圖像名(每張圖像絕對路徑list)
self.imgs=[os.path.join(root,k) for k in imgs]
//圖像預處理包括(張量化,尺寸固定,旋轉,隨機裁剪)
self.transforms=transforms
def__getitem__(self,index):
img_path=self.imgs[index]
pil_img=Image.open(img_path)
if self.transforms:
data=self.transforms(pil_img)
else:
//圖像先變成python的array數組
pil_img=np.asarray(pil_img)
//然後將python array數組轉換爲torch的張量
data=torch.from_numpy(pil_img)
def __len__(self):
return len(self.imgs)
#定義一個加載器
######
root="d:/imgdata"
image_datasets=Kesci_set(root,data_transforms)
dataloaders=data.DataLoader(image_datasets,batch_size=16,num_workers=0,shuffle=False)
#使用
#############
for data in dataloaders:
img=data
n,c,h,w=img.size()
.....
注意二點:
1、處理圖像時,必須要把讀取的圖像轉化我pytorch tensor張量形式。所以如果你添加了transforms的話一定要有transforms.ToTensor()這一句。
2、如果是用windows系統跑,設num_worker=0,windows多線程讀取數據可能會異常報錯。