目錄
1.可迭代對象,迭代器
首先,我們要明白python中的兩個概念:可迭代對象,迭代器。
- 可迭代對象:
- 實現了
__iter__
方法,該方法返回一個迭代器對象。
-
迭代器:
-
一個帶狀態的對象,內部持有一個狀態,該狀態用於記錄當前迭代所在的位置,以方便下次迭代的時候獲取正確的元素。
-
迭代器含有
__iter__和__next__
方法。當調用__iter__
返回迭代器自身,當調用next()
方法的時候,返回容器中的下一個值。 - 迭代器就像一個懶加載的工廠,等到有人需要的時候纔給它生成值返回,沒調用的時候就處於休眠狀態等待下一次調用。
- iter()函數:
1. 用法一:iter(callable, sentinel)
不停的調用callable,直至其的返回值等於sentinel。其中的callable可以是函數,方法或實現了__call__方法的實例。
2. 用法二:iter(collection)
1)iter()直接調用可迭代對象的__iter__(),並把__iter__()的返回結果作爲自己的返回值,故該用法常被稱爲“創建迭代器”。
2)iter函數可以顯示調用,或當執行“for i in obj:”,Python解釋器會在第一次迭代時自動調用iter(obj),之後的迭代會調用迭代器的next方法,for語句會自動處理最後拋出的StopIteration異常。
3)但iter函數獲取不到 __iter__方法時,還會調用 __getitem__方法,參數是從0開始能獲取值就是可迭代的。
2.數據集遍歷的一般化流程
for i, data in enumerate(dataLoader):
enumerate(
dataloader
)會調用dataloader
的__iter__()
方法, 產生了一個DataLoaderIter(迭代器),接着
調用DataLoaderIter
的__next__()方法
來得到batch data。 在__next__()方法方法中使用_next_index()方法獲得索引,接着通過dataset_fetcher的fetch()方法根據index
調用dataset的__getitem__()
方法, 然後用collate_fn
來把它們打包成batch。當數據讀完後,__next__()
拋出一個StopIteration
異常,for
循環結束,dataloader
失效.
3.Dataset
torch.utils.data.Dataset
是代表這一數據的抽象類(也就是基類)。我們可以通過繼承和重寫這個抽象類實現自己的數據類,只需要定義__len__
和__getitem__
這個兩個函數
如果在類中定義了__getitem__()方法,那麼實例對象(假設爲P)就可以這樣P[key]取值。當實例對象做P[key]操作時,就會調用類中的__getitem__()方法。
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
- 通過複寫 __getitem__ 方法可以通過索引index來訪問數據,能夠同時返回數據和對應的標籤(label),這裏的數據和標籤都爲tensor類型
- 通過複寫 __len__ 方法來獲取數據的個數。
比如:
class MyDataset(Dataset):
""" my dataset."""
# Initialize your data, download, etc.
def __init__(self):
# 讀取csv文件中的數據
xy = np.loadtxt('data-diabetes.csv', delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
# 除去最後一列爲數據位,存在x_data中
self.x_data = torch.from_numpy(xy[:, 0:-1])
# 最後一列爲標籤爲,存在y_data中
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
# 根據索引返回數據和對應的標籤
return self.x_data[index], self.y_data[index]
def __len__(self):
# 返回文件數據的數目
return self.len
4.TensorDataset
TensorDataset是Dataset的子類,已經複寫了__len__和__getitem__方法,只需傳入張量即可。
class TensorDataset(Dataset):
"""Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
def __init__(self, *tensors):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
- 比如:
可以看出我們把X和Y通過Data.TensorDataset() 這個函數拼裝成了一個數據集,數據集的類型是【TensorDataset】
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
)
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
5.Dataloader
DataLoader是Pytorch中用來處理模型輸入數據的一個工具類。組合了數據集(dataset) + 採樣器(sampler),並在數據集上提供單線程或多線程(num_workers )的可迭代對象。
- 基本概念:
- epoch: 所有的訓練樣本輸入到模型中稱爲一個epoch;
- iteration: 一批樣本輸入到模型中,成爲一個Iteration;
- batchszie:批大小,決定一個epoch有多少個Iteration;
迭代次數(iteration)=樣本總數(epoch)/批尺寸(batchszie)
- 函數原型:
torch.utils.data.DataLoader(dataset, batch_size=1,
shuffle=False, sampler=None,
batch_sampler=None, num_workers=0,
collate_fn=None, pin_memory=False,
drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None)
- 參數:
-
dataset (Dataset) – 決定數據從哪讀取或者從何讀取;
-
batch_size (python:int, optional) – 批尺寸(每次訓練樣本個數,默認爲1)
-
shuffle (bool, optional) –每一個 epoch是否爲亂序 (default:
False
). -
num_workers (python:int, optional) – 是否多進程讀取數據(默認爲0);
-
drop_last (bool, optional) – 當樣本數不能被batchsize整除時,最後一批數據是否捨棄(default:
False
) -
pin_memory(bool, optional) - 如果爲True會將數據放置到GPU上去(默認爲false)
參考:
https://blog.csdn.net/u014380165/article/details/78634829
https://blog.csdn.net/zw__chen/article/details/82806900
https://www.cnblogs.com/yongjieShi/p/10456802.html
https://www.cnblogs.com/ranjiewen/p/10128046.html
Python 子類繼承父類構造函數:https://www.runoob.com/w3cnote/python-extends-init.html
https://www.ziiai.com/blog/259
Python可迭代對象,迭代器,生成器的區別:https://blog.csdn.net/jinixin/article/details/72232604
完全理解Python迭代對象、迭代器、生成器:https://foofish.net/iterators-vs-generators.html
Pytorch中的數據加載藝術:http://studyai.com/article/11efc2bf
PyTorch 數據集(Dataset):https://geek-docs.com/pytorch/pytorch-tutorial/pytorch-dataset.html