PyTorch中torch.utils.data.DataLoader加載數據

torch.utils.data.DataLoader使用方法

       DataLoader是PyTorch中的一種數據類型,在PyTorch架構中訓練或者驗證模型經常要使用它,那麼怎麼生成以及使用這樣的數據類型?
在這裏插入圖片描述


一、參數設置

torch.utils.data.DataLoader(
      dataset   			#數據加載
      batch_size = 1		#批處理樣本大小
      shuffle = False		#是否在每一輪epoch打亂樣本順序
      sampler = None		#指定數據加載中使用的索引/鍵的序列
      batch_sampler = None	#和sampler類似
      num_workers = 0		#是否進行多進程加載數據設置
      collat​​e_fn = None		#是否合併樣本列表以形成一小批Tensor
      pin_memory = False	#如果True,數據加載器會在返回之前將Tensors複製到CUDA固定內存
      drop_last = False		#True若數據集大小不能被batch_size整除,則刪除最後一個不完整的批處理。
      timeout = 0			#如果爲正,則爲從工作人員收集批處理的超時值
      worker_init_fn = None )

       具體可參考官方文檔

1、dataset:(數據類型 Dataset)
       輸入的數據類型,也是最重要的參數,它表示要加載數據的數據集對象。

2、batch_size:(數據類型 int)
       批處理樣本的大小,默認爲1。

3、shuffle:(數據類型 bool)
       在每輪迭代訓練時是否將數據洗牌。默認設置爲False。設置爲True則是在每一輪中,輸入數據的順序將被打亂,這是爲了使數據更有獨立性,訓練的時候一般都設置爲True,若輸入數據是有序的,就不要設置成True了。

4、collate_fn:(數據類型 callable可調用對象)
       將一小段數據合併成數據列表,默認設置是False。如果設置成True,系統會在返回前會將張量數據(Tensors)複製到CUDA內存中。

5、sampler:(數據類型 Sampler)
       採樣,默認設置爲None。根據定義的策略從數據集中採樣輸入。如果定義採樣規則,則洗牌(shuffle)設置必須爲False。

6、num_workers:(數據類型 Int)
       子進程數量,默認是0。使用多少個子進程來加載數據。0 就是使用主進程來加載數據。注意:這個數字必須是大於等於0的,該值的設置應該量內存大小而爲

7、pin_memory:(數據類型 bool)
       內存寄存,默認爲False。在數據返回前,是否將數據複製到CUDA內存中。

8、drop_last:(數據類型 bool)
       丟棄最後數據,默認爲False。設置了 batch_size 的數目後,最後一批數據未必是設置的數目,有可能會小些。這時你是否需要丟棄這批數據。

9、timeout:(數據類型 numeric)
       超時值,默認爲0。是用來設置數據讀取的超時時間,超過這個時間還沒讀取到數據的話就會報錯。 所以,數值必須大於等於0。


二、實際應用

import torch
from torch.utils.data import Dataset, DataLoader

#---------------預處理-----------------
transform = transforms.Compose([
    	transforms.Resize((224, 224), 2),
    	transforms.ToTensor(),
    	transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
#--------------數據加載----------------
trainset = torchvision.datasets.CIFAR10(root='./data', 
										train=True, 
										download=False, 
										transform=transform)
# torch.utils.data.DataLoader
trainloader = DataLoader(dataset=trainset, 
							batch_size=32, 
							shuffle=True, 
							num_workers=0)  

for epoch in range(100):
	running_loss = 0.0
	batch_size = 32
	for i, data in enumerate(trainloader, 0):
    	inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章