数据集详解(待更新)

前言

这篇博客记录下数据集的详细信息,具体数据集看我具体用到哪个,再查资料补充起来。

一、MNIST

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
训练数据集包含 60,000样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示。
在这里插入图片描述

相关代码:

def get_mnist_data(data_folder_path, batch_size=64):
    """ mnist data
    Args:
        train_batch_size(int): training batch size 
        test_batch_size(int): test batch size
    Returns:
        (torch.utils.data.DataLoader): train loader 
        (torch.utils.data.DataLoader): test loader
    """
    train_data = datasets.MNIST(data_folder_path, train=True,  download=True, 
        transform=transforms.Compose([
            transforms.ToTensor(), 
            #transforms.Normalize((0.1307,), (0.3081,))
            ])
        )

    test_data  = datasets.MNIST(data_folder_path, train=False, download=True, 
        transform=transforms.Compose([
            transforms.ToTensor(), 
            #transforms.Normalize((0.1307,), (0.3081,))
            ])
        )

    kwargs = {'num_workers': 4, 'pin_memory': True}

    weights = [1 for data, label in train_data]
    sampler = WeightedRandomSampler(weights, num_samples=5000, \
                                    replacement=True)
    train_loader = torch.utils.data.DataLoader(train_data, 
        batch_size=batch_size, shuffle=False,sampler=sampler, **kwargs)
    test_loader  = torch.utils.data.DataLoader(test_data,  
        batch_size=batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader

获取方式

详见这篇

参考文献:

[1]、http://yann.lecun.com/exdb/mnist/
[2]、https://www.cnblogs.com/xianhan/p/9145966.html

二、CIFAR-10

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
CIFAR-10一共包含 60,00032 × 32 大小的图片,一共10个类别,每个类别有 6,000图片。其中,训练集有 50,000图片,测试集有 10,000图片。
在这里插入图片描述

相关代码

def get_cifar10_data(data_folder_path, batch_size=64):
    """ cifar10 data
    Args:
        train_batch_size(int): training batch size 
        test_batch_size(int): test batch size
    Returns:
        (torch.utils.data.DataLoader): train loader 
        (torch.utils.data.DataLoader): test loader
    """
    transform_train = transforms.Compose([

        #transforms.RandomCrop(32, padding=4),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        #transforms.Normalize((0.4913, 0.4821, 0.4465), (0.2470, 0.2434, 0.2615)),

    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Normalize((0.4913, 0.4821, 0.4465), (0.2470, 0.2434, 0.2615)),
    ])

    train_data = datasets.CIFAR10(data_folder_path, train=True, 
        download=True, transform=transform_train)
    test_data  = datasets.CIFAR10(data_folder_path, train=False, 
        download=True, transform=transform_test) 

    kwargs = {'num_workers': 4, 'pin_memory': True}
    weights = [1 for data, label in train_data]
    sampler = WeightedRandomSampler(weights, num_samples=500, \
                                    replacement=True)
    train_loader = torch.utils.data.DataLoader(train_data,
        batch_size=batch_size, shuffle=False,sampler=sampler, **kwargs)
    test_loader  = torch.utils.data.DataLoader(test_data,
        batch_size=batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader

获取方式

详见这篇

参考文献

[1]、http://www.cs.toronto.edu/~kriz/cifar.html
[2]、https://www.cnblogs.com/Jerry-Dong/p/8109938.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章