數據集詳解(待更新)

前言

這篇博客記錄下數據集的詳細信息,具體數據集看我具體用到哪個,再查資料補充起來。

一、MNIST

MNIST 數據集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據。
訓練數據集包含 60,000樣本, 測試數據集包含 10,000 樣本. 在 MNIST 數據集中的每張圖片由 28 x 28 個像素點構成, 每個像素點用一個灰度值表示。
在這裏插入圖片描述

相關代碼:

def get_mnist_data(data_folder_path, batch_size=64):
    """ mnist data
    Args:
        train_batch_size(int): training batch size 
        test_batch_size(int): test batch size
    Returns:
        (torch.utils.data.DataLoader): train loader 
        (torch.utils.data.DataLoader): test loader
    """
    train_data = datasets.MNIST(data_folder_path, train=True,  download=True, 
        transform=transforms.Compose([
            transforms.ToTensor(), 
            #transforms.Normalize((0.1307,), (0.3081,))
            ])
        )

    test_data  = datasets.MNIST(data_folder_path, train=False, download=True, 
        transform=transforms.Compose([
            transforms.ToTensor(), 
            #transforms.Normalize((0.1307,), (0.3081,))
            ])
        )

    kwargs = {'num_workers': 4, 'pin_memory': True}

    weights = [1 for data, label in train_data]
    sampler = WeightedRandomSampler(weights, num_samples=5000, \
                                    replacement=True)
    train_loader = torch.utils.data.DataLoader(train_data, 
        batch_size=batch_size, shuffle=False,sampler=sampler, **kwargs)
    test_loader  = torch.utils.data.DataLoader(test_data,  
        batch_size=batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader

獲取方式

詳見這篇

參考文獻:

[1]、http://yann.lecun.com/exdb/mnist/
[2]、https://www.cnblogs.com/xianhan/p/9145966.html

二、CIFAR-10

CIFAR-10 是由 Hinton 的學生 Alex Krizhevsky 和 Ilya Sutskever 整理的一個用於識別普適物體的小型數據集。一共包含 10 個類別的 RGB 彩色圖 片:飛機( a叩lane )、汽車( automobile )、鳥類( bird )、貓( cat )、鹿( deer )、狗( dog )、蛙類( frog )、馬( horse )、船( ship )和卡車( truck )。
CIFAR-10一共包含 60,00032 × 32 大小的圖片,一共10個類別,每個類別有 6,000圖片。其中,訓練集有 50,000圖片,測試集有 10,000圖片。
在這裏插入圖片描述

相關代碼

def get_cifar10_data(data_folder_path, batch_size=64):
    """ cifar10 data
    Args:
        train_batch_size(int): training batch size 
        test_batch_size(int): test batch size
    Returns:
        (torch.utils.data.DataLoader): train loader 
        (torch.utils.data.DataLoader): test loader
    """
    transform_train = transforms.Compose([

        #transforms.RandomCrop(32, padding=4),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        #transforms.Normalize((0.4913, 0.4821, 0.4465), (0.2470, 0.2434, 0.2615)),

    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Normalize((0.4913, 0.4821, 0.4465), (0.2470, 0.2434, 0.2615)),
    ])

    train_data = datasets.CIFAR10(data_folder_path, train=True, 
        download=True, transform=transform_train)
    test_data  = datasets.CIFAR10(data_folder_path, train=False, 
        download=True, transform=transform_test) 

    kwargs = {'num_workers': 4, 'pin_memory': True}
    weights = [1 for data, label in train_data]
    sampler = WeightedRandomSampler(weights, num_samples=500, \
                                    replacement=True)
    train_loader = torch.utils.data.DataLoader(train_data,
        batch_size=batch_size, shuffle=False,sampler=sampler, **kwargs)
    test_loader  = torch.utils.data.DataLoader(test_data,
        batch_size=batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader

獲取方式

詳見這篇

參考文獻

[1]、http://www.cs.toronto.edu/~kriz/cifar.html
[2]、https://www.cnblogs.com/Jerry-Dong/p/8109938.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章