【深度學習】CNN + CIFAR10 學習筆記(數據輸入 mini-batch)(基於 TENSORFLOW)

前言

爲加快梯度下降收斂速度,採用了 MINI-BATCH 的方法進行數據供給,每次給予 BATCH_SIZE 項數據進行運算。

代碼

cifar10_input.py

#cifar10_input.py

import numpy as np
import cv2
import linecache

# define
#data_path = "../data/train/"
#labels_path = "../data/trainLabels.csv"

# 批量獲取 [l, r).png 圖片 
def get_X(X_path, l, r):
    """Load the images ``l.png`` .. ``(r - 1).png`` into a single array.

    Args:
        X_path: Directory containing the images, each named ``<index>.png``.
        l: First image index (inclusive).
        r: End image index (exclusive).

    Returns:
        A float64 array of shape ``(r - l, 32, 32, 3)``. Row ``i - l`` holds
        image ``i`` with every pixel value divided by 255. Channels are in
        the order ``cv2.imread`` produces (BGR).

    Raises:
        FileNotFoundError: If an image is missing or cannot be decoded.
    """
    result = np.zeros([r - l, 32, 32, 3])
    for i in range(l, r):
        path = X_path + '/%d.png' % i
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; without this check the division below fails
        # with an unhelpful TypeError that does not name the file.
        if image is None:
            raise FileNotFoundError('cannot read image: %s' % path)
        result[i - l] = image / 255
    return result

# 批量獲取 [l, r) 標籤 
def get_labels(labels_path, l, r):
    """Return one-hot labels for the images with indices in ``[l, r)``.

    The labels file is a CSV whose rows are ``<image index>,<class name>``.
    It may or may not start with a header row: the row for image ``i`` is
    looked up on line ``i`` (no header) and on line ``i + 1`` (one header
    line), and the one whose first field equals ``i`` is used.

    Args:
        labels_path: Path to the labels CSV file.
        l: First image index (inclusive).
        r: End image index (exclusive).

    Returns:
        A float64 array of shape ``(r - l, 10)``; row ``i - l`` has a single
        1 in the column of image ``i``'s class.

    Raises:
        ValueError: If no row for an index is found, or its class name is
            not one of the ten known classes.
    """
    names = ('airplane', 'automobile', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck')
    column = {label: j for j, label in enumerate(names)}
    result = np.zeros([r - l, 10])
    for i in range(l, r):
        label = None
        # Using the image index directly as a line number is off by one when
        # the CSV has a header row, so match on the id field instead.
        for line_no in (i, i + 1):
            fields = linecache.getline(labels_path, line_no).strip().split(',')
            if len(fields) >= 2 and fields[0].strip() == str(i):
                label = fields[1].strip()
                break
        if label is None:
            raise ValueError('no label row for image %d in %s' % (i, labels_path))
        if label not in column:
            # An all-zero row would silently corrupt training; fail loudly.
            raise ValueError('unknown label %r for image %d' % (label, i))
        result[i - l][column[label]] = 1
    return result
            
# 獲取數組 arr 中要求的數據
def get_random_X(X_path, arr):
    """Load the images whose indices are listed in ``arr``.

    Args:
        X_path: Directory containing the images, each named ``<index>.png``.
        arr: Sequence of integer image indices, in the order wanted.

    Returns:
        A float64 array of shape ``(len(arr), 32, 32, 3)``. Row ``k`` holds
        image ``arr[k]`` with every pixel value divided by 255. Channels are
        in the order ``cv2.imread`` produces (BGR).

    Raises:
        FileNotFoundError: If an image is missing or cannot be decoded.
    """
    result = np.zeros([len(arr), 32, 32, 3])
    for row, image_id in enumerate(arr):
        path = X_path + '/%d.png' % image_id
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; without this check the division below fails
        # with an unhelpful TypeError that does not name the file.
        if image is None:
            raise FileNotFoundError('cannot read image: %s' % path)
        result[row] = image / 255
    return result

# 獲取數組 arr 中要求的標籤
def get_random_labels(labels_path, arr):
    """Return one-hot labels for the images whose indices are in ``arr``.

    The labels file is a CSV whose rows are ``<image index>,<class name>``.
    It may or may not start with a header row: the row for image ``i`` is
    looked up on line ``i`` (no header) and on line ``i + 1`` (one header
    line), and the one whose first field equals ``i`` is used.

    Args:
        labels_path: Path to the labels CSV file.
        arr: Sequence of integer image indices, in the order wanted.

    Returns:
        A float64 array of shape ``(len(arr), 10)``; row ``k`` has a single
        1 in the column of image ``arr[k]``'s class.

    Raises:
        ValueError: If no row for an index is found, or its class name is
            not one of the ten known classes.
    """
    names = ('airplane', 'automobile', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck')
    column = {label: j for j, label in enumerate(names)}
    result = np.zeros([len(arr), 10])
    for row, raw_id in enumerate(arr):
        # Plain int so the id comparison and line lookup behave the same for
        # Python ints and NumPy integer scalars.
        image_id = int(raw_id)
        label = None
        # Using the image index directly as a line number is off by one when
        # the CSV has a header row, so match on the id field instead.
        for line_no in (image_id, image_id + 1):
            fields = linecache.getline(labels_path, line_no).strip().split(',')
            if len(fields) >= 2 and fields[0].strip() == str(image_id):
                label = fields[1].strip()
                break
        if label is None:
            raise ValueError(
                'no label row for image %d in %s' % (image_id, labels_path))
        if label not in column:
            # An all-zero row would silently corrupt training; fail loudly.
            raise ValueError('unknown label %r for image %d' % (label, image_id))
        result[row][column[label]] = 1
    return result


# 定義 cifar10 類
class cifar10(object):
    """A contiguous range ``[begin, end)`` of image indices and their labels.

    Attributes are set by the constructor:
        dataset_dir: Directory passed on to the image loaders.
        labels_dir: Labels file path passed on to the label loaders.
        begin: First image index of the range (inclusive), as an int.
        end: End image index of the range (exclusive), as an int.
    """

    def __init__(self, dataset_dir, labels_dir, l, r):
        """Store the data locations and the index range ``[l, r)``.

        ``l`` and ``r`` are converted with ``int()``, so float bounds are
        truncated.
        """
        self.dataset_dir, self.labels_dir = dataset_dir, labels_dir
        self.begin, self.end = int(l), int(r)

    def next_batch(self, batch_size):
        """Draw a random mini-batch from the range.

        ``batch_size`` indices are drawn with ``np.random.randint`` from
        ``[begin, end)``; the same index may appear more than once.

        Returns:
            A tuple ``(images, labels, indices)`` where ``images`` and
            ``labels`` come from ``get_random_X`` / ``get_random_labels``
            for the drawn ``indices``.
        """
        picks = np.random.randint(self.begin, self.end, batch_size)
        return (
            get_random_X(self.dataset_dir, picks),
            get_random_labels(self.labels_dir, picks),
            picks,
        )

    def all_data(self):
        """Return ``(images, labels)`` for every index in ``[begin, end)``."""
        lo, hi = self.begin, self.end
        images = get_X(self.dataset_dir, lo, hi)
        targets = get_labels(self.labels_dir, lo, hi)
        return images, targets
        
        
# 獲取 dataset 包含 train 和 test 的數據集, 其中 test 爲總數據的 10%
def read_dataset(dataset_dir, labels_dir, dataset_size):
    """Build an object exposing ``test`` and ``train`` ``cifar10`` splits.

    Image indices start at 1. The first ``0.1 * dataset_size`` indices form
    the test split and the remaining indices up to ``dataset_size`` form the
    train split; the (possibly fractional) boundary is truncated to an int
    by the ``cifar10`` constructor.

    Args:
        dataset_dir: Image directory handed to both splits.
        labels_dir: Labels file path handed to both splits.
        dataset_size: Total number of images.

    Returns:
        An object with attributes ``test`` and ``train``.
    """
    class Dataset(object):
        pass

    # First index of the train split == end (exclusive) of the test split.
    boundary = (0.1 * dataset_size) + 1

    splits = Dataset()
    splits.test = cifar10(dataset_dir, labels_dir, 1, boundary)
    splits.train = cifar10(dataset_dir, labels_dir, boundary, dataset_size + 1)
    return splits

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章