CIFAR-10 Dataset



CIFAR-10 數據集介紹

CIFAR-10 數據集由60000張 32x32 的彩色圖片構成,其中每6000張圖片一個類別,共10個類別。其中,訓練集包含50000張圖片,測試集10000張圖片。

數據集被分割爲5個訓練 batch 和1個測試 batch。訓練 batch 包含從每個類別中抽取的 5000張圖片,測試 batch 包含從每個類別中隨機抽取的100張圖片。

CIFAR-10 數據集示例如下:

下載數據集前往官方網站 直接下載即可。


查看 CIFAR-10 數據集

官方網站 分別給出了 Python2 和 Python3 的讀取數據集代碼。下面是 Python3 讀取數據集的代碼:

def unpickle(file):
    """
    It is a function to unpickle the batch file.
    :param file: data_batch_1-5
    :return:
    """
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict


繪製 CIFAR-10 數據集內圖片

在官網上下載的數據集文件夾內,有 data_batch_1, …, data_batch_55個文件存放用於訓練的圖片。

使用 unpickle 函數可以讀取出文件內容,但是是以 numpy.ndarray 形式展現的,不是可視化的圖片。本步驟的目的,是把 ndarray 形式轉化爲可視的圖片。

因爲 CIFAR-10 數據集內是彩色圖片,所以是RGB形式的,輸入具有3個通道。將矩陣進行一個 reshape 操作,方便繪製。

轉換代碼如下:

def cifar10_plot(data, meta, im_idx=0):
    # Get the image data np.ndarray
    im = data[b'data'][im_idx, :]

    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    # 1-D arrays.shape = (N, ) ----> reshape to (1, N, 1)
    # 2-D arrays.shape = (M, N) ---> reshape to (M, N, 1)
    img = np.dstack((im_r, im_g, im_b))
    # img.shape = (32, 32, 3)

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]])

    plt.imshow(img)
    plt.show()


完整代碼

我們想通過命令行的格式,輸入 0-49999 內的數字,實現對訓練集圖片的任意展示。完整代碼如下;

# coding=utf8
"""
@author: Yantong Lai
@date: 08/21/2019
"""

import pickle
import matplotlib.pyplot as plt
import argparse
import numpy as np
import os

CIFAR10 = "../data/cifar-10-batches-py/"

parser = argparse.ArgumentParser("Plot training images in CIFAR10 dataset.")
parser.add_argument("-i", "--image", type=int, default=0,
                    help="Index of the image in ")
args = parser.parse_args()


def unpickle(file):
    """
    It is a function to unpickle the batch file.
    :param file: data_batch_1-5
    :return:
    """
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def cifar10_plot(data, meta, im_idx=0):
    # Get the image data np.ndarray
    im = data[b'data'][im_idx, :]

    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    # 1-D arrays.shape = (N, ) ----> reshape to (1, N, 1)
    # 2-D arrays.shape = (M, N) ---> reshape to (M, N, 1)
    img = np.dstack((im_r, im_g, im_b))
    # img.shape = (32, 32, 3)

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]])

    plt.imshow(img)
    plt.show()

def main():
    batch = (args.image // 10000) + 1
    idx = args.image - (batch-1)*10000

    data = unpickle(os.path.join(CIFAR10, "data_batch_" + str(batch)))
    meta = unpickle(os.path.join(CIFAR10, "batches.meta"))

    cifar10_plot(data, meta, im_idx=idx)


if __name__ == "__main__":
    main()

其中,batch 返回的是1-5之間的整數,用以讀取 data_batch_batch文件,idx 就是該圖片在data_batch_batch 文件內的索引。

在 Terminal 中運行代碼:

$ python3 CIFAR10.py -i 23478

運行結果如圖:


總結

本文是對 CIFAR10 數據集的簡介,以及可視化代碼講解。

項目地址爲 https://github.com/icmpnorequest/Pytorch-Learning/blob/master/Python3/CIFAR10.py ,可自行前往下載完整代碼。

本人水平有限,文章或代碼有不妥之處,請給我留言或者在 Github 上提 issue,如果喜歡我的文章或代碼,請給我點贊或 star。

謝謝!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章