CIFAR-10 數據集介紹
CIFAR-10 數據集由60000張 32x32 的彩色圖片構成,其中每6000張圖片一個類別,共10個類別。其中,訓練集包含50000張圖片,測試集10000張圖片。
數據集被分割爲5個訓練 batch 和1個測試 batch。訓練 batch 包含從每個類別中抽取的 5000張圖片,測試 batch 包含從每個類別中隨機抽取的100張圖片。
CIFAR-10 數據集示例如下:
下載數據集前往官方網站 直接下載即可。
查看 CIFAR-10 數據集
官方網站 分別給出了 Python2 和 Python3 的讀取數據集代碼。下面是 Python3 讀取數據集的代碼:
def unpickle(file):
"""
It is a function to unpickle the batch file.
:param file: data_batch_1-5
:return:
"""
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
繪製 CIFAR-10 數據集內圖片
在官網上下載的數據集文件夾內,有 data_batch_1
, …, data_batch_5
5個文件存放用於訓練的圖片。
使用 unpickle
函數可以讀取出文件內容,但是是以 numpy.ndarray 形式展現的,不是可視化的圖片。本步驟的目的,是把 ndarray 形式轉化爲可視的圖片。
因爲 CIFAR-10 數據集內是彩色圖片,所以是RGB形式的,輸入具有3個通道。將矩陣進行一個 reshape
操作,方便繪製。
轉換代碼如下:
def cifar10_plot(data, meta, im_idx=0):
# Get the image data np.ndarray
im = data[b'data'][im_idx, :]
im_r = im[0:1024].reshape(32, 32)
im_g = im[1024:2048].reshape(32, 32)
im_b = im[2048:].reshape(32, 32)
# 1-D arrays.shape = (N, ) ----> reshape to (1, N, 1)
# 2-D arrays.shape = (M, N) ---> reshape to (M, N, 1)
img = np.dstack((im_r, im_g, im_b))
# img.shape = (32, 32, 3)
print("shape: ", img.shape)
print("label: ", data[b'labels'][im_idx])
print("category:", meta[b'label_names'][data[b'labels'][im_idx]])
plt.imshow(img)
plt.show()
完整代碼
我們想通過命令行的格式,輸入 0-49999 內的數字,實現對訓練集圖片的任意展示。完整代碼如下;
# coding=utf8
"""
@author: Yantong Lai
@date: 08/21/2019
"""
import pickle
import matplotlib.pyplot as plt
import argparse
import numpy as np
import os
CIFAR10 = "../data/cifar-10-batches-py/"
parser = argparse.ArgumentParser("Plot training images in CIFAR10 dataset.")
parser.add_argument("-i", "--image", type=int, default=0,
help="Index of the image in ")
args = parser.parse_args()
def unpickle(file):
"""
It is a function to unpickle the batch file.
:param file: data_batch_1-5
:return:
"""
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
def cifar10_plot(data, meta, im_idx=0):
# Get the image data np.ndarray
im = data[b'data'][im_idx, :]
im_r = im[0:1024].reshape(32, 32)
im_g = im[1024:2048].reshape(32, 32)
im_b = im[2048:].reshape(32, 32)
# 1-D arrays.shape = (N, ) ----> reshape to (1, N, 1)
# 2-D arrays.shape = (M, N) ---> reshape to (M, N, 1)
img = np.dstack((im_r, im_g, im_b))
# img.shape = (32, 32, 3)
print("shape: ", img.shape)
print("label: ", data[b'labels'][im_idx])
print("category:", meta[b'label_names'][data[b'labels'][im_idx]])
plt.imshow(img)
plt.show()
def main():
batch = (args.image // 10000) + 1
idx = args.image - (batch-1)*10000
data = unpickle(os.path.join(CIFAR10, "data_batch_" + str(batch)))
meta = unpickle(os.path.join(CIFAR10, "batches.meta"))
cifar10_plot(data, meta, im_idx=idx)
if __name__ == "__main__":
main()
其中,batch 返回的是1-5之間的整數,用以讀取 data_batch_batch
文件,idx 就是該圖片在data_batch_batch
文件內的索引。
在 Terminal 中運行代碼:
$ python3 CIFAR10.py -i 23478
運行結果如圖:
總結
本文是對 CIFAR10 數據集的簡介,以及可視化代碼講解。
項目地址爲 https://github.com/icmpnorequest/Pytorch-Learning/blob/master/Python3/CIFAR10.py ,可自行前往下載完整代碼。
本人水平有限,文章或代碼有不妥之處,請給我留言或者在 Github 上提 issue,如果喜歡我的文章或代碼,請給我點贊或 star。
謝謝!