# PyTorch 學習（二）：自定義數據集

#在深度學習中經常需要生成帶標籤的圖片名稱列表,即 xxxlist.txt 文件,
#下面的腳本爲文件夾中的圖片生成帶標籤的 txt 列表文件
import os 
# File extensions treated as images when building a list file (compared
# case-insensitively). Same set that torchvision's ImageFolder recognises,
# so both loading paths in this script see the same files.
_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
                   '.pgm', '.tif', '.tiff', '.webp')


def generate(dir, label, list_name='train.txt', prefix=None,
             extensions=_IMG_EXTENSIONS):
    """Write a labelled image-list file for the images in *dir*.

    Every image file in *dir* (in sorted name order) becomes one line of the
    form ``/<prefix>/<file name> <label>``. Non-image files, including a list
    file left over from a previous run, are skipped.

    Args:
        dir: Folder holding the images of one class.
        label: Class label written after every file name (coerced via int()).
        list_name: Name of the list file created inside *dir*.
        prefix: Sub-path written before each file name. Defaults to the last
            component of *dir* (e.g. ``cat`` for ``.../train/cat``).
        extensions: Lower-case file extensions that count as images.
    """
    if prefix is None:
        prefix = os.path.basename(os.path.normpath(dir))
    label = int(label)
    files = sorted(os.listdir(dir))
    print("*****************")
    print("input = ", dir)
    print("Start...")
    # os.path.join instead of a hard-coded '\\' so this also works outside
    # Windows; 'with' closes the file even if a write fails.
    with open(os.path.join(dir, list_name), 'w') as list_text:
        for file_name in files:
            # splitext (not split) yields the extension; keep images only.
            if os.path.splitext(file_name)[1].lower() not in extensions:
                continue
            list_text.write('/' + prefix + '/' + file_name + ' '
                            + str(label) + '\n')
    print("done!")
    print("********************")


if __name__ == '__main__':
    generate('D:\\Spyder3Files\\data\\train\\cat', 0)
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
# Custom dataset: subclass Dataset and override __len__ and __getitem__.
class MyDataset(Dataset):
    """Image dataset driven by a "<relative path> <label>" list file.

    Each non-blank line of *names_file* holds an image path followed by a
    space and an integer label, e.g. ``/cat/1.jpg 0`` (the format written by
    ``generate``). The path is appended verbatim to *root_dir*, so it
    normally starts with '/'.

    Samples are dicts ``{'image': <result of io.imread>, 'label': int}``.
    They are passed through *transform* when one is given.

    Raises:
        FileNotFoundError: if *names_file* is missing (at construction), or
            if a listed image is missing (when that sample is fetched).
    """

    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []

        # Fail with a clear error instead of printing and then crashing in open().
        if not os.path.isfile(self.names_file):
            raise FileNotFoundError(self.names_file + " does not exist!")
        with open(self.names_file) as file:
            for line in file:
                line = line.strip()
                # Skip blank lines (e.g. a trailing empty line) so they are
                # not counted as samples that cannot be parsed later.
                if line:
                    self.names_list.append(line)
        self.size = len(self.names_list)

    def __len__(self):
        return self.size

    @staticmethod
    def _parse_line(line):
        """Split one list-file line into (relative image path, int label).

        The label is taken after the LAST space, so paths that contain
        spaces stay intact.
        """
        path, label = line.rsplit(' ', 1)
        return path, int(label)

    def __getitem__(self, idx):
        # names_list[idx] raises IndexError past the end, which is also what
        # ends a plain `for sample in dataset` loop.
        rel_path, label = self._parse_line(self.names_list[idx])
        image_path = self.root_dir + rel_path
        if not os.path.isfile(image_path):
            # Raise instead of returning None, which would otherwise surface
            # as a confusing TypeError in the caller or in DataLoader collate.
            raise FileNotFoundError(image_path + " does not exist!")
        image = io.imread(image_path)

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample


# Load the training list without any transform and preview the first 20
# samples on a 4 x 5 grid, each titled with its label.
train_dataset = MyDataset(
    root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=None,
)

print(train_dataset.size)

plt.figure()
for cnt, sample in enumerate(train_dataset):
    image, label = sample['image'], sample['label']

    ax = plt.subplot(4, 5, cnt + 1)
    ax.axis('off')
    ax.imshow(image)
    ax.set_title(f'label {label}')
    plt.pause(0.001)

    # Stop once the 20th cell of the 4 x 5 grid has been drawn.
    if cnt == 19:
        break

#  變換Resize    
class Resize:
    """Transform that resizes the 'image' of a sample dict via skimage.

    Args:
        output_size: target size handed to ``skimage.transform.resize``.
    """

    def __init__(self, output_size: tuple):
        self.output_size = output_size

    def __call__(self, sample):
        # Build a fresh sample dict: resized image, label passed through.
        resized = transform.resize(sample['image'], self.output_size)
        return {'image': resized, 'label': sample['label']}

#  變換ToTensor
class ToTensor(object):
    """Transform a sample's H x W x C numpy image into a C x H x W tensor.

    A 2-D (grayscale) H x W image first gets a trailing channel axis, so it
    becomes a 1 x H x W tensor instead of making np.transpose fail. The
    tensor keeps the array's dtype; the label is passed through unchanged.
    """

    def __call__(self, sample):
        image = sample['image']
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)
        # H x W x C (numpy/skimage layout) -> C x H x W (torch layout).
        image_new = np.transpose(image, (2, 0, 1))
        return {'image': torch.from_numpy(image_new),
                'label': sample['label']}
        
# Apply the custom Resize + ToTensor transforms to the original training set.
transformed_trainset = MyDataset(
    root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=transforms.Compose([Resize((224, 224)), ToTensor()]),
)

# DataLoader provides batching, shuffling and optional worker processes.
# num_workers=0 loads everything in the main process. This script runs
# code at module level with no __main__ guard, which is presumably why
# workers were switched off.
trainset_dataloader = DataLoader(
    dataset=transformed_trainset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

#  可視化
def show_images_batch(sample_batched):
    """Draw a whole batch as a single image grid on the current axes.

    The batch labels, in batch order, are shown as the title, matching what
    show_batch_images does for ImageFolder batches.

    Args:
        sample_batched: batch dict from trainset_dataloader. 'image' holds
            the stacked C x H x W tensors produced by ToTensor; 'label' holds
            the labels.
    """
    images_batch = sample_batched['image']
    labels_batch = sample_batched['label']
    grid = make_grid(images_batch)
    # make_grid returns C x H x W; imshow expects H x W x C.
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.title('labels: ' + ' '.join(str(int(lab)) for lab in labels_batch))


# Each batch from trainset_dataloader is a dict. Its 'image' entry stacks
# the C x H x W tensors from ToTensor into an N x C x H x W tensor.
plt.figure()
for batch in trainset_dataloader:
    show_images_batch(batch)
    plt.axis('off')
    plt.ioff()
    plt.show()


plt.show()
#  使用更簡便的方式——ImageFolder
#  如果每種類別的樣本放在各自的文件夾中,則可以直接使用ImageFolder.
import torch 
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
import matplotlib.pyplot as plt
import numpy as np

# Resize to 224 x 224, randomly mirror left/right, then convert to a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# ImageFolder derives each sample's class from its sub-folder of root.
train_dataset = datasets.ImageFolder(root='./data/train', transform=data_transform)

train_dataloader = DataLoader(
    dataset=train_dataset, batch_size=4, shuffle=True, num_workers=0
)

def show_batch_images(sample_batch):
    """Plot every image of an ImageFolder batch in one row, titled by label.

    Args:
        sample_batch: ``(images, labels)`` pair yielded by a DataLoader over
            ImageFolder. ``images`` stacks C x H x W tensors and ``labels``
            has one entry per image.
    """
    images_batch, labels_batch = sample_batch[0], sample_batch[1]
    # Use the real batch size: the last batch can hold fewer than
    # batch_size images, and a hard-coded 4 raised IndexError there.
    count = len(labels_batch)
    for i, (image, label) in enumerate(zip(images_batch, labels_batch)):
        ax = plt.subplot(1, count, i + 1)
        # C x H x W tensor -> H x W x C array, the layout imshow expects.
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title(str(label.item()))
        ax.axis('off')

        
# Display the ImageFolder batches, calling plt.show() once per batch.
plt.figure()
for batch in train_dataloader:
    show_batch_images(batch)

    plt.show()

 

# 發表評論
# 所有評論
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
# 相關文章