torch學習記錄

數據流

Dataset

pytorch實現了一個基類Dataset來幫助構建數據集對象。要想實現自己的數據集類需要基於torch.utils.data.Dataset來完成,並需要在類中實現兩個方法,分別是:

__getitem__:用於返回一條數據,DATA_CLASS[index]相當於DATA_CLASS.__getitem__(index)
__len__:用於返回樣本數量,len(DATA_CLASS)相當於DATA_CLASS.__len__()

ImageFolder
ImageFolder是torch在Dataset類基礎上實現的一個常用的數據集對象。它假設所有文件都按文件夾保存,每個文件夾儲存同一類別的圖像,文件夾爲類名。如果是分類任務,則可通過調用ImageFolder類來完成數據集加載,不需要再自己寫了。

DataLoader

Dataset僅用於數據加載->返回樣本,但實際訓練時,還需要對數據進行shuffle和並行加速。DataLoader用於實現這些功能。

一些常用函數

  1. permute
    permute用於對原數據進行維度調換,例如x.permute(1,0,2),將原本的第一維換成第0維
  2. squeeze和unsqueeze
    squeeze用於去掉元素數量爲1的維度,例如x.squeeze()
    unsqueeze用於增加維度,例如b.unsqueeze(dim)
發佈了18 篇原創文章 · 獲贊 16 · 訪問量 2萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章