前言
這篇博客記錄下數據集的詳細信息,具體數據集看我具體用到哪個,再查資料補充起來。
一、MNIST
MNIST 數據集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據。
訓練數據集包含 60,000 個樣本, 測試數據集包含 10,000 樣本. 在 MNIST 數據集中的每張圖片由 28 x 28 個像素點構成, 每個像素點用一個灰度值表示。
相關代碼:
def get_mnist_data(data_folder_path, batch_size=64):
""" mnist data
Args:
train_batch_size(int): training batch size
test_batch_size(int): test batch size
Returns:
(torch.utils.data.DataLoader): train loader
(torch.utils.data.DataLoader): test loader
"""
train_data = datasets.MNIST(data_folder_path, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
#transforms.Normalize((0.1307,), (0.3081,))
])
)
test_data = datasets.MNIST(data_folder_path, train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
#transforms.Normalize((0.1307,), (0.3081,))
])
)
kwargs = {'num_workers': 4, 'pin_memory': True}
weights = [1 for data, label in train_data]
sampler = WeightedRandomSampler(weights, num_samples=5000, \
replacement=True)
train_loader = torch.utils.data.DataLoader(train_data,
batch_size=batch_size, shuffle=False,sampler=sampler, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data,
batch_size=batch_size, shuffle=False, **kwargs)
return train_loader, test_loader
獲取方式
參考文獻:
[1]、http://yann.lecun.com/exdb/mnist/
[2]、https://www.cnblogs.com/xianhan/p/9145966.html
二、CIFAR-10
CIFAR-10 是由 Hinton 的學生 Alex Krizhevsky 和 Ilya Sutskever 整理的一個用於識別普適物體的小型數據集。一共包含 10 個類別的 RGB 彩色圖 片:飛機( a叩lane )、汽車( automobile )、鳥類( bird )、貓( cat )、鹿( deer )、狗( dog )、蛙類( frog )、馬( horse )、船( ship )和卡車( truck )。
CIFAR-10一共包含 60,000 個 32 × 32 大小的圖片,一共10個類別,每個類別有 6,000 個圖片。其中,訓練集有 50,000 個圖片,測試集有 10,000 個圖片。
相關代碼
def get_cifar10_data(data_folder_path, batch_size=64):
""" cifar10 data
Args:
train_batch_size(int): training batch size
test_batch_size(int): test batch size
Returns:
(torch.utils.data.DataLoader): train loader
(torch.utils.data.DataLoader): test loader
"""
transform_train = transforms.Compose([
#transforms.RandomCrop(32, padding=4),
#transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
#transforms.Normalize((0.4913, 0.4821, 0.4465), (0.2470, 0.2434, 0.2615)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
#transforms.Normalize((0.4913, 0.4821, 0.4465), (0.2470, 0.2434, 0.2615)),
])
train_data = datasets.CIFAR10(data_folder_path, train=True,
download=True, transform=transform_train)
test_data = datasets.CIFAR10(data_folder_path, train=False,
download=True, transform=transform_test)
kwargs = {'num_workers': 4, 'pin_memory': True}
weights = [1 for data, label in train_data]
sampler = WeightedRandomSampler(weights, num_samples=500, \
replacement=True)
train_loader = torch.utils.data.DataLoader(train_data,
batch_size=batch_size, shuffle=False,sampler=sampler, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data,
batch_size=batch_size, shuffle=False, **kwargs)
return train_loader, test_loader
獲取方式
參考文獻
[1]、http://www.cs.toronto.edu/~kriz/cifar.html
[2]、https://www.cnblogs.com/Jerry-Dong/p/8109938.html