自定義製作python版本的CIFAR數據集
1、準備圖像
(以製作小數據集爲例,便於理解)
這裏自定義製作的數據集只包含2個類:dog,parrot,每個類有121張圖像。數據集共有242張圖像,測試圖像30張,訓練圖像212張。將數據集分爲1個測試批次和2個訓練批次。測試批次包含每個類的15張圖像。每個訓練批次包含106張圖像,但是其中屬於各個類的圖像數量隨機(即不同訓練批次中相同類的圖像數量不一定相等)。
圖片的命名規則爲 “label_類別名_編號.jpg”,這裏規定,label爲0時類別名爲dog,label爲1時類別名爲parrot。
2、數據集理解
首先調整所有圖像的大小,這裏調整爲256×256(img_dim=256)。
def img_resize(img_dir, img_dim):
    """Resize every image in ``img_dir`` to ``img_dim`` x ``img_dim`` pixels.

    The resized copies are written as RGB JPEG files into a sibling
    directory named ``img_dir + '_resize'``; the originals are not modified.
    Each output keeps the source file's base name with a single ``.jpg``
    extension, so sources that differ only by extension map to the same
    output file.

    Args:
        img_dir: Path of the folder holding one batch of images.
        img_dim: Side length, in pixels, of the square output images.

    Returns:
        Path of the directory that contains the resized images.
    """
    img_resized_dir = img_dir + '_resize'  # where the resized copies are saved
    os.makedirs(img_resized_dir, exist_ok=True)
    for img_name in os.listdir(img_dir):
        img_path = os.path.join(img_dir, img_name)
        if not os.path.isfile(img_path):
            # os.listdir also returns sub-directories; they are not images.
            continue
        # Drop the source extension so the output is "name.jpg" rather than
        # "name.jpg.jpg".
        stem = os.path.splitext(img_name)[0]
        with Image.open(img_path) as img:
            # JPEG cannot store alpha or palette data, so normalise to RGB.
            # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS,
            # which was removed in Pillow 10.
            out = img.convert('RGB').resize((img_dim, img_dim), Image.LANCZOS)
        out.save('{}/{}.jpg'.format(img_resized_dir, stem))
    print('Images in {} are resized as {}×{}.\n'.format(img_dir, img_dim, img_dim))
    return img_resized_dir
cifar數據集中每個批次文件包含一個字典,字典內有4個鍵,分別是:'batch_label'、'data'、'filenames'、'labels'。注意:本文生成的批次文件中,這些鍵是以bytes形式保存的(例如 b'filenames'),讀取時需要使用bytes鍵。可以使用下面的unpickle代碼查看。
- 'batch_label' = 當前批次的名字。
- 'data' = 形狀爲(106,256×256×3)的uint8的numpy數組。數組的每行存儲一張圖像的數字信息,按通道順序爲紅、綠、藍存儲,每個通道按行優先。
- 'filenames' = 一個包含該批次所有圖像名稱的列表,長度爲106。
- 'labels' = 一個取值爲0、1的列表,長度爲106。索引i處的數字爲第i個圖像的標籤。標籤0表示dog,標籤1表示parrot。
def unpickle(file):
    """Load and return the object pickled in ``file``.

    Args:
        file: Path of a pickled batch file.

    Returns:
        The unpickled object (a dictionary for the batch files in this
        tutorial).
    """
    import pickle  # local import keeps this snippet self-contained
    with open(file, 'rb') as fo:
        # NOTE: pickle can execute arbitrary code while loading; only open
        # files from a trusted source.
        # 'latin-1' only affects 8-bit strings pickled by Python 2; bytes
        # written by Python 3 are returned unchanged as bytes.
        batch = pickle.load(fo, encoding='latin-1')
    return batch
# Inspect the first generated training batch.
cc = unpickle("images_cifar/data_batch_1")
print(cc.keys())
# The generated batch files store their dictionary keys as bytes, so the
# lookup must use a bytes key; a str key would raise KeyError.
print(cc[b'filenames'])
3、完整運行代碼
自定義的圖像批次保存在images文件夾中,生成的cifar數據集文件保存在images_cifar文件夾中。
from PIL import Image
from numpy import *
import numpy as np
import os
import pickle
def img_resize(img_dir, img_dim):
    """Resize every image in ``img_dir`` to ``img_dim`` x ``img_dim`` pixels.

    The resized copies are written as RGB JPEG files into a sibling
    directory named ``img_dir + '_resize'``; the originals are not modified.
    Each output keeps the source file's base name with a single ``.jpg``
    extension, so sources that differ only by extension map to the same
    output file.

    Args:
        img_dir: Path of the folder holding one batch of images.
        img_dim: Side length, in pixels, of the square output images.

    Returns:
        Path of the directory that contains the resized images.
    """
    img_resized_dir = img_dir + '_resize'  # where the resized copies are saved
    os.makedirs(img_resized_dir, exist_ok=True)
    for img_name in os.listdir(img_dir):
        img_path = os.path.join(img_dir, img_name)
        if not os.path.isfile(img_path):
            # os.listdir also returns sub-directories; they are not images.
            continue
        # Drop the source extension so the output is "name.jpg" rather than
        # "name.jpg.jpg".
        stem = os.path.splitext(img_name)[0]
        with Image.open(img_path) as img:
            # JPEG cannot store alpha or palette data, so normalise to RGB.
            # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS,
            # which was removed in Pillow 10.
            out = img.convert('RGB').resize((img_dim, img_dim), Image.LANCZOS)
        out.save('{}/{}.jpg'.format(img_resized_dir, stem))
    print('Images in {} are resized as {}×{}.\n'.format(img_dir, img_dim, img_dim))
    return img_resized_dir
def get_filenames_and_labels(img_resized_dir):
    """Collect the file names and integer labels of one batch.

    File names are expected to look like ``label_class_index.jpg``; the
    label is the integer before the first underscore. Entries are returned
    in ``os.listdir`` order, the same listing ``get_img_data`` iterates, so
    index ``i`` refers to the same image in both results.

    Args:
        img_resized_dir: Directory holding the resized images of one batch.

    Returns:
        A ``(filenames, labels)`` tuple: ``filenames`` is a list of UTF-8
        encoded file names (bytes) and ``labels`` the matching list of ints.

    Raises:
        ValueError: If a file name does not start with an integer label.
    """
    filenames = []
    labels = []
    for img_name in os.listdir(img_resized_dir):
        # "0_dog_12.jpg" -> "0_dog_12" -> "0"
        label_text = img_name.split('.')[0].split('_')[0]
        try:
            label = int(label_text)
        except ValueError as err:
            # Name the offending file; a bare int() error does not.
            raise ValueError(
                'Cannot parse a label from file name {!r}; expected '
                '"label_class_index.jpg"'.format(img_name)
            ) from err
        filenames.append(img_name.encode('utf-8'))
        labels.append(label)
    return filenames, labels
def get_img_data(img_resized_dir):
    """Load every image of a batch into one CIFAR-style uint8 array.

    Each image becomes one row: all red values first, then all green, then
    all blue, each channel flattened in row-major order. Rows follow
    ``os.listdir`` order.

    Args:
        img_resized_dir: Directory holding the resized images of one batch.
            The images are assumed to share one size so the rows stack into
            a rectangular array.

    Returns:
        A ``numpy.uint8`` array with one row per image (empty if the
        directory contains no entries).
    """
    rows = []
    for img_name in os.listdir(img_resized_dir):
        img_path = os.path.join(img_resized_dir, img_name)
        with Image.open(img_path) as img:
            # Force three channels so grayscale/RGBA files do not break the
            # channel layout; the result has shape (height, width, 3).
            pixels = np.asarray(img.convert('RGB'), dtype=np.uint8)
        # Move the channel axis first, then flatten: R plane, G plane, B plane.
        rows.append(pixels.transpose(2, 0, 1).reshape(-1))
    return np.array(rows, dtype=np.uint8)
if __name__ == '__main__':
    # Build the CIFAR-style batch files: source images are read from
    # images/<batch name> and the pickled batches are written to
    # images_cifar/<batch name>.
    img_dir_names = ['test_batch']  # 1 test batch
    num_data_batch = 2  # 2 training batches
    for i in range(1, num_data_batch + 1):
        img_dir_names.append('data_batch_' + str(i))

    # The output folder must exist before any batch file is opened for
    # writing; open() does not create missing directories.
    os.makedirs('images_cifar', exist_ok=True)

    count = 0  # running index of the training batch being written
    for img_dir_name in img_dir_names:
        img_dir = 'images/' + img_dir_name
        filepath = 'images_cifar/' + img_dir_name
        img_resized_dir = img_resize(img_dir, img_dim=256)

        # Keys and text values are stored as bytes in the pickled dictionary.
        data_batch = {}
        if 'test' in filepath:
            data_batch['batch_label'.encode('utf-8')] = 'testing batch 1 of 1'.encode('utf-8')
        else:
            count += 1
            batch_label = 'training batch ' + str(count) + ' of ' + str(num_data_batch)
            data_batch['batch_label'.encode('utf-8')] = batch_label.encode('utf-8')

        filenames, labels = get_filenames_and_labels(img_resized_dir)
        data = get_img_data(img_resized_dir)
        data_batch['filenames'.encode('utf-8')] = filenames
        data_batch['labels'.encode('utf-8')] = labels
        data_batch['data'.encode('utf-8')] = data
        with open(filepath, 'wb') as f:
            pickle.dump(data_batch, f)

    # Mapping from numeric label to class name, saved next to the batches.
    img_classes = 'images_cifar/batches.meta'
    label_names = {0: 'dog', 1: 'parrot'}
    with open(img_classes, 'wb') as f:
        pickle.dump(label_names, f)
# def unpickle(file):
# import pickle
# with open(file, 'rb') as fo:
# dict = pickle.load(fo, encoding='latin-1')
# return dict
#
#
# cc = unpickle("C:/Users/lenovo/.keras/datasets/cifar-10-batches-py/data_batch_1")
# print(cc.keys())
# print(cc['filenames'])