前言
为加快梯度下降收敛速度,采用了 MINI-BATCH 的方法进行数据供给,每次给予 BATCH_SIZE 项数据进行运算。
代码
cifar10_input.py
#cifar10_input.py
import numpy as np
import cv2
import linecache
# define
#data_path = "../data/train/"
#labels_path = "../data/trainLabels.csv"
# Batch-load images numbered [l, r) as <X_path>/<i>.png
def get_X(X_path, l, r):
    """Load images l..r-1 from X_path into a float array.

    Args:
        X_path: directory containing images named '<i>.png'.
        l, r: half-open index range [l, r) of image numbers to load.

    Returns:
        numpy array of shape (r - l, 32, 32, 3), pixel values scaled to [0, 1].
        Channel order is whatever cv2.imread produces (BGR).

    Raises:
        FileNotFoundError: if any image in the range cannot be read.
    """
    result = np.zeros([r - l, 32, 32, 3])
    for i in range(l, r):
        path = X_path + '/%d.png' % i
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the assignment below fails with an opaque TypeError.
        if image is None:
            raise FileNotFoundError('cannot read image: %s' % path)
        result[i - l] = image / 255
    return result
# Batch-load one-hot labels for lines [l, r) of the labels CSV
def get_labels(labels_path, l, r):
    """Read lines l..r-1 of a CSV ('id,label' per line) as one-hot vectors.

    Args:
        labels_path: path to the labels CSV file.
        l, r: half-open range [l, r) of 1-based line numbers to read.

    Returns:
        numpy array of shape (r - l, 10); each row is the one-hot encoding of
        that line's class name, or all zeros if the line is missing, malformed,
        or names an unknown class.
    """
    result = np.zeros([r - l, 10])
    name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    # Dict lookup instead of a linear scan per row.
    class_index = {label: j for j, label in enumerate(name)}
    for i in range(l, r):
        # linecache.getline returns '' for out-of-range lines; guard against
        # short/empty lines, which previously raised IndexError on parts[1].
        parts = linecache.getline(labels_path, i).split(',')
        if len(parts) < 2:
            continue
        j = class_index.get(parts[1].strip())
        if j is not None:
            result[i - l][j] = 1
    return result
# Load the images whose numbers are listed in arr
def get_random_X(X_path, arr):
    """Load the images named by the indices in arr.

    Args:
        X_path: directory containing images named '<i>.png'.
        arr: sequence of image numbers to load.

    Returns:
        numpy array of shape (len(arr), 32, 32, 3), pixel values in [0, 1].
        Channel order is whatever cv2.imread produces (BGR).

    Raises:
        FileNotFoundError: if any requested image cannot be read.
    """
    length = len(arr)
    result = np.zeros([length, 32, 32, 3])
    for t, idx in enumerate(arr):
        path = X_path + '/%d.png' % idx
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread returns None on failure; fail loudly with the bad path
        # rather than letting the assignment raise an opaque TypeError.
        if image is None:
            raise FileNotFoundError('cannot read image: %s' % path)
        result[t] = image / 255
    return result
# Load the one-hot labels whose line numbers are listed in arr
def get_random_labels(labels_path, arr):
    """Read the CSV lines numbered by arr as one-hot label vectors.

    Args:
        labels_path: path to the labels CSV file ('id,label' per line).
        arr: sequence of 1-based line numbers to read.

    Returns:
        numpy array of shape (len(arr), 10); each row is the one-hot encoding
        of that line's class name, or all zeros if the line is missing,
        malformed, or names an unknown class.
    """
    result = np.zeros([len(arr), 10])
    name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    # Dict lookup instead of a linear scan per row.
    class_index = {label: j for j, label in enumerate(name)}
    for t, line_no in enumerate(arr):
        # linecache.getline returns '' for out-of-range lines; guard against
        # short/empty lines, which previously raised IndexError on parts[1].
        parts = linecache.getline(labels_path, line_no).split(',')
        if len(parts) < 2:
            continue
        j = class_index.get(parts[1].strip())
        if j is not None:
            result[t][j] = 1
    return result
# CIFAR-10 data feeder over a half-open index range [begin, end)
class cifar10(object):
    """Mini-batch provider for CIFAR-10 images stored as <dataset_dir>/<i>.png
    with a matching labels CSV at labels_dir."""

    def __init__(self, dataset_dir, labels_dir, l, r):
        """Record the data locations and the index range [l, r) (coerced to int)."""
        self.dataset_dir = dataset_dir
        self.labels_dir = labels_dir
        self.begin = int(l)
        self.end = int(r)

    def next_batch(self, batch_size):
        """Sample batch_size indices uniformly (with replacement) from
        [begin, end) and return (images, one-hot labels, indices)."""
        picks = np.random.randint(self.begin, self.end, batch_size)
        batch_X = get_random_X(self.dataset_dir, picks)
        batch_y = get_random_labels(self.labels_dir, picks)
        return batch_X, batch_y, picks

    def all_data(self):
        """Return every image and label in [begin, end) as (X, y)."""
        images = get_X(self.dataset_dir, self.begin, self.end)
        labels = get_labels(self.labels_dir, self.begin, self.end)
        return images, labels
# Build a dataset object with .train and .test feeders; test = first 10% of the data
def read_dataset(dataset_dir, labels_dir, dataset_size):
    """Split the data into train/test cifar10 feeders.

    Args:
        dataset_dir: directory of '<i>.png' images.
        labels_dir: path to the labels CSV.
        dataset_size: total number of samples (1-based indices 1..dataset_size).

    Returns:
        An object with .test covering the first 10% of indices and .train
        covering the remaining 90%.
    """
    # Integer division instead of `0.1 * dataset_size`: the float product can
    # round below the exact value and truncate the split boundary off by one.
    test_size = dataset_size // 10
    class Dataset(object):
        pass
    dataset = Dataset()
    dataset.test = cifar10(dataset_dir, labels_dir, 1, test_size + 1)
    dataset.train = cifar10(dataset_dir, labels_dir, test_size + 1, dataset_size + 1)
    return dataset