keras中提供的cifar10數據集可能因爲網速等問題無法直接下載讀取,此時可以進入官網將數據集下載到本地,網址:
http://www.cs.toronto.edu/~kriz/cifar.html
這裏我們下載python版本的數據集(CIFAR-10 python version)。
將下載的tar.gz形式的文件解壓,放到想要存放數據文件的文件夾中,這裏我的文件存放位置爲"/Users/shiruihuo/Documents/study/深度學習/data/cifar10/cifar-10-batches-py"。使用以下腳本可以正確地讀取並轉換train和test的數據及標籤。
# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os
def load_CIFAR_batch(filename):
    """Load a single CIFAR batch file.

    Args:
        filename: Path to one pickled batch file (e.g. ``data_batch_1``
            or ``test_batch``).

    Returns:
        A tuple ``(X, Y)``. ``X`` is a float array of shape
        ``(N, 32, 32, 3)`` (channels last), where ``N`` is the number of
        images stored in the file. ``Y`` is an array built from the
        file's label list.
    """
    # NOTE: unpickling can execute arbitrary code; only load batch files
    # that come from a trusted source.
    with open(filename, 'rb') as f:
        # encoding='bytes' is why the dict is indexed with bytes keys below.
        datadict = p.load(f, encoding='bytes')

    # Each row is reshaped into 3 planes of 32x32 values. Using -1 lets
    # numpy infer the image count instead of hard-coding 10000, so batch
    # files of any size load correctly.
    X = np.asarray(datadict[b'data']).reshape(-1, 3, 32, 32)
    # Move the channel axis last: (N, 3, 32, 32) -> (N, 32, 32, 3).
    X = X.transpose(0, 2, 3, 1).astype("float")
    Y = np.array(datadict[b'labels'])
    return X, Y
def load_CIFAR10(ROOT):
    """Load every CIFAR-10 batch found under a directory.

    Args:
        ROOT: Directory containing the files ``data_batch_1`` through
            ``data_batch_5`` and ``test_batch``.

    Returns:
        A tuple ``(Xtr, Ytr, Xte, Yte)``. ``Xtr`` and ``Ytr`` are the five
        training batches concatenated along the first axis, in batch
        order. ``Xte`` and ``Yte`` are the data and labels of
        ``test_batch``. Each batch is read with ``load_CIFAR_batch``.
    """
    train_batches = [
        load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % (b,)))
        for b in range(1, 6)
    ]
    # Stack the per-batch arrays along axis 0 (the sample axis).
    Xtr = np.concatenate([data for data, _ in train_batches])
    Ytr = np.concatenate([labels for _, labels in train_batches])
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
import numpy as np
#from julyedu.data_utils import load_CIFAR10
import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = (10.0, 8.0)
# plt.rcParams['image.interpolation'] = 'nearest'
# plt.rcParams['image.cmap'] = 'gray'
# Load the CIFAR-10 dataset from the locally extracted archive.
cifar10_dir = '/Users/shiruihuo/Documents/study/深度學習/data/cifar10/cifar-10-batches-py'
x_train, y_train, x_test, y_test = load_CIFAR10(cifar10_dir)

# Print the shape of each returned array as a quick sanity check.
for _caption, _array in (
    ('Training data shape: ', x_train),
    ('Training labels shape: ', y_train),
    ('Test data shape: ', x_test),
    ('Test labels shape: ', y_test),
):
    print(_caption, _array.shape)