DataLoader是PyTorch中的一種數據類型,在PyTorch架構中訓練或者驗證模型經常要使用它,那麼怎麼生成以及使用這樣的數據類型?
一、參數設置
torch.utils.data.DataLoader(
dataset #數據加載
batch_size = 1 #批處理樣本大小
shuffle = False #是否在每一輪epoch打亂樣本順序
sampler = None #指定數據加載中使用的索引/鍵的序列
batch_sampler = None #和sampler類似
num_workers = 0 #是否進行多進程加載數據設置
collate_fn = None #是否合併樣本列表以形成一小批Tensor
pin_memory = False #如果True,數據加載器會在返回之前將Tensors複製到CUDA固定內存
drop_last = False #True若數據集大小不能被batch_size整除,則刪除最後一個不完整的批處理。
timeout = 0 #如果爲正,則爲從工作人員收集批處理的超時值
worker_init_fn = None )
具體可參考官方文檔。
1、dataset:(數據類型 Dataset)
輸入的數據類型,也是最重要的參數,它表示要加載數據的數據集對象。
2、batch_size:(數據類型 int)
批處理樣本的大小,默認爲1。
3、shuffle:(數據類型 bool)
在每輪迭代訓練時是否將數據洗牌。默認設置爲False。設置爲True則是在每一輪中,輸入數據的順序將被打亂,這是爲了使數據更有獨立性,訓練的時候一般都設置爲True,若輸入數據是有序的,就不要設置成True了。
4、collate_fn:(數據類型 callable可調用對象)
將一小段數據合併成數據列表,默認設置是False。如果設置成True,系統會在返回前會將張量數據(Tensors)複製到CUDA內存中。
5、sampler:(數據類型 Sampler)
採樣,默認設置爲None。根據定義的策略從數據集中採樣輸入。如果定義採樣規則,則洗牌(shuffle)設置必須爲False。
6、num_workers:(數據類型 Int)
子進程數量,默認是0。使用多少個子進程來加載數據。0 就是使用主進程來加載數據。注意:這個數字必須是大於等於0的,該值的設置應該量內存大小而爲。
7、pin_memory:(數據類型 bool)
內存寄存,默認爲False。在數據返回前,是否將數據複製到CUDA內存中。
8、drop_last:(數據類型 bool)
丟棄最後數據,默認爲False。設置了 batch_size 的數目後,最後一批數據未必是設置的數目,有可能會小些。這時你是否需要丟棄這批數據。
9、timeout:(數據類型 numeric)
超時值,默認爲0。是用來設置數據讀取的超時時間,超過這個時間還沒讀取到數據的話就會報錯。 所以,數值必須大於等於0。
二、實際應用
import torch
from torch.utils.data import Dataset, DataLoader
#---------------預處理-----------------
transform = transforms.Compose([
transforms.Resize((224, 224), 2),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
#--------------數據加載----------------
trainset = torchvision.datasets.CIFAR10(root='./data',
train=True,
download=False,
transform=transform)
# torch.utils.data.DataLoader
trainloader = DataLoader(dataset=trainset,
batch_size=32,
shuffle=True,
num_workers=0)
for epoch in range(100):
running_loss = 0.0
batch_size = 32
for i, data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)