數據流
Dataset
pytorch實現了一個基類Dataset來幫助構建數據集對象。要想實現自己的數據集類需要基於torch.utils.data.Dataset
來完成,並需要在類中實現兩個方法,分別是:
__getitem__:用於返回一條數據,DATA_CLASS[index]相當於DATA_CLASS.__getitem__(index)
__len__:用於返回樣本數量,len(DATA_CLASS)相當於DATA_CLASS.__len__()
ImageFolder
ImageFolder是torch在Dataset類基礎上實現的一個常用的數據集對象。它假設所有文件都按文件夾保存,每個文件夾儲存同一類別的圖像,文件夾爲類名。如果是分類任務,則可通過調用ImageFolder類來完成數據集加載,不需要再自己寫了。
DataLoader
Dataset僅用於數據加載->返回樣本,但實際訓練時,還需要對數據進行shuffle和並行加速。DataLoader用於實現這些功能。
一些常用函數
- permute
permute用於對原數據進行維度調換,例如x.permute(1,0,2),將原本的第一維換成第0維 - squeeze和unsqueeze
squeeze用於去掉元素數量爲1的維度,例如x.squeeze()
unsqueeze用於增加維度,例如b.unsqueeze(dim)