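# Deep-learning pipelines often need a list of image names with labels (an xxxlist.txt file).
# This script generates such a labelled txt file for the images in a folder.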
import os
def generate(dir, label, prefix='/cat'):
    """Write ``train.txt`` inside *dir*, listing its files with a label.

    Every output line has the form ``<prefix>/<file name> <label>``. The
    directory entries are sorted so the list is reproducible.

    Args:
        dir: Directory whose entries are listed; ``train.txt`` is created
            (or overwritten) inside it.
        label: Class label for every entry; converted with ``int()``.
        prefix: Text written in front of ``/<file name>`` on each line.
            Defaults to ``'/cat'``.

    Raises:
        OSError: If *dir* cannot be listed or the list file cannot be written.
        ValueError: If *label* cannot be converted to an integer.
    """
    list_name = 'train.txt'
    files = sorted(os.listdir(dir))
    print("*****************")
    print("input = ", dir)
    print("Start...")
    # Convert once, before the output file is truncated, so an invalid label
    # does not destroy an existing list.
    label_text = str(int(label))
    # os.path.join keeps the list inside ``dir`` on every platform; a literal
    # backslash is only a path separator on Windows.
    with open(os.path.join(dir, list_name), 'w') as listText:
        for file in files:
            # A list left over from an earlier run is not an image and must
            # not be listed in itself.
            if file == list_name:
                continue
            listText.write(prefix + '/' + file + ' ' + label_text + '\n')
    print("down!")
    print("********************")
if __name__ == '__main__':
    # Script entry point: label every file of the cat training folder as class 0.
    generate(dir='D:\\Spyder3Files\\data\\train\\cat', label=0)
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
# Define MyDataset: subclass Dataset and override __len__ and __getitem__.
class MyDataset(Dataset):
    """Dataset driven by a text file of ``<relative path> <label>`` lines.

    Args:
        root_dir: Prefix that is concatenated (plain string ``+``) with the
            relative path of each entry to form the image path.
        names_file: Path of the text file listing one sample per line.
        transform: Optional callable applied to the sample dict.

    Attributes:
        names_list: The non-blank lines of ``names_file``, whitespace-stripped.
        size: Number of entries in ``names_list``.

    Raises:
        FileNotFoundError: If ``names_file`` is not an existing file.
    """

    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []
        if not os.path.isfile(self.names_file):
            # Fail with a clear message instead of printing and then crashing
            # on the open() call.
            raise FileNotFoundError(self.names_file + " does not exist!")
        # The context manager closes the names file once it has been read.
        with open(self.names_file) as names:
            for line in names:
                entry = line.strip()
                # Blank lines carry no sample and must not count towards size.
                if entry:
                    self.names_list.append(entry)
        self.size = len(self.names_list)

    def __len__(self):
        """Return the number of listed samples."""
        return self.size

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': int}`` for entry *idx*.

        Returns ``None`` (after printing a message) when the image file is
        missing. Raises ``IndexError`` for an out-of-range *idx* and
        ``ValueError`` for a line without a label field.
        """
        entry = self.names_list[idx]
        # The label is the last space-separated field; splitting from the
        # right keeps relative paths that themselves contain spaces intact.
        relative_path, separator, label_text = entry.rpartition(' ')
        if not separator:
            raise ValueError(
                "malformed line in " + self.names_file + ": " + repr(entry))
        image_path = self.root_dir + relative_path
        if not os.path.isfile(image_path):
            print(image_path + " does not exist!")
            return None
        image = io.imread(image_path)
        label = int(label_text)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
# Build the untransformed training set and report how many samples it lists.
train_dataset = MyDataset(
    root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=None,
)
print(train_dataset.size)

# Preview at most the first 20 samples on a 4x5 grid of subplots.
plt.figure()
for cnt, i in enumerate(train_dataset):
    image, label = i['image'], i['label']
    ax = plt.subplot(4, 5, cnt + 1)
    ax.axis('off')
    ax.imshow(image)
    ax.set_title('label {}'.format(label))
    plt.pause(0.001)
    # Index 19 is the 20th sample, i.e. the last cell of the grid.
    if cnt == 19:
        break
# Resize transform
class Resize(object):
    """Sample transform that rescales the image to a fixed output size."""

    def __init__(self, output_size: tuple):
        # Target size handed to transform.resize on every call.
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the image resized and the label kept."""
        # Rescale the image with skimage.transform.
        resized = transform.resize(sample['image'], self.output_size)
        return dict(image=resized, label=sample['label'])
# ToTensor transform
class ToTensor(object):
    """Sample transform that turns the ndarray image into a torch tensor."""

    def __call__(self, sample):
        """Return a new sample dict whose image is a tensor; the label is kept.

        The image axes are reordered with (2, 0, 1), i.e. the last axis moves
        to the front (HxWxC -> CxHxW for a colour image array).
        """
        channels_first = np.transpose(sample['image'], (2, 0, 1))
        tensor_image = torch.from_numpy(channels_first)
        return dict(image=tensor_image, label=sample['label'])
# Apply the transforms to the raw training set: resize first, then convert
# the image to a tensor.
transformed_trainset = MyDataset(
    root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=transforms.Compose([
        Resize((224, 224)),
        ToTensor(),
    ]),
)

# DataLoader adds batching, shuffling and optional worker-based loading.
trainset_dataloader = DataLoader(
    dataset=transformed_trainset,
    batch_size=4,
    shuffle=True,
    num_workers=0,  # note: set to 0 so that loading happens in the main thread
)
# Visualisation
def show_images_batch(sample_batched):
    """Draw all images of one batch as a single grid on the current axes.

    Args:
        sample_batched: Mapping with an 'image' entry (passed to make_grid)
            and a 'label' entry (read but not displayed).
    """
    images_batch = sample_batched['image']
    labels_batch = sample_batched['label']
    grid = make_grid(images_batch)
    # Reorder the grid's axes with (1, 2, 0) before handing it to imshow.
    grid_array = grid.numpy().transpose(1, 2, 0)
    plt.imshow(grid_array)
# sample_batch: Tensor, NxCxHxW
# Display every batch of the transformed training set, one grid per batch.
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
    show_images_batch(sample_batch)
    plt.axis('off')
    plt.ioff()
    plt.show()

plt.show()
# A simpler approach: ImageFolder
# If the samples of each class are stored in their own sub-folder, ImageFolder can be used directly.
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
import matplotlib.pyplot as plt
import numpy as np
# Preprocessing pipeline: resize, random horizontal flip, then tensor conversion.
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# ImageFolder reads the samples below ./data/train and applies the pipeline.
train_dataset = datasets.ImageFolder(
    root='./data/train',
    transform=data_transform,
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
def show_batch_images(sample_batch):
    """Plot every image of one ``(images, labels)`` batch side by side.

    Args:
        sample_batch: Indexable pair; element 0 holds the images and
            element 1 the labels. Each image is transposed with axes
            (1, 2, 0) before display, and each label must support
            ``.item()``.
    """
    images_batch = sample_batch[0]
    labels_batch = sample_batch[1]
    # Use the real batch length: the final batch of an epoch can hold fewer
    # than four samples, which made the fixed range(4) raise IndexError.
    batch_len = len(labels_batch)
    # Keep the four-column layout for batches of up to four images and widen
    # it only when a batch is larger.
    columns = max(batch_len, 4)
    for i in range(batch_len):
        label_ = labels_batch[i].item()
        image_ = np.transpose(images_batch[i], (1, 2, 0))
        ax = plt.subplot(1, columns, i + 1)
        ax.imshow(image_)
        ax.set_title(str(label_))
        ax.axis('off')
        # plt.pause(0.001)  # not needed when multi-threaded loading is not used
# Display the ImageFolder batches one after another.
plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
    show_batch_images(sample_batch)
    plt.show()