1.簡介
雖然Pytorch-Geometric提供了很多官方數據集,但是當需要構建自己的數據集的時候,就需要對如何使用dataset
基類構造自己的數據集有所瞭解。庫中提供了兩個構建數據集的基類:torch_geometric.data.Dataset
和torch_geometric.data.InMemoryDataset
,其中torch_geometric.data.InMemoryDataset
繼承了torch_geometric.data.Dataset
,表示是否將整個數據集加載到內存中。
根據torchvision
的習慣,每一個數據集都需要指定一個根目錄,根目錄下面需要分爲兩個文件夾,一個是raw_dir
,這個表示下載的原始數據的存放位置,另一個是processed_dir
,表示處理後的數據集存放位置。
另外,每一個數據集函數都可以傳遞函數transform
,pre_transform
和pre_filter
,默認爲None
。transform
函數用於數據對象被加載使用之前進行的動態轉換(一般用於數據增強
);pre_transform
函數將數據對象保存到磁盤以前進行的轉換,也就是得到processed_dir
內數據文件之前對其調用(一般用於只需要計算一次的複雜預處理過程);pre_filter
函數在數據進行保存之前進行過濾。
2.創建一次讀入內存的數據
構建torch_geometric.data.InMemoryDataset
,需要重寫(區分重載和重寫)四個函數:
(1)torch_geometric.data.InMemoryDataset.raw_file_names()
存放raw_dir
目錄下所有數據文件名的字符串列表,用於下載時的檢查過程(正如之前的文章提到的,數據集下載的時候會檢測是否已經存在,避免重複下載,也就是如何避免自動下載的httperror
的解決方案)。
(2)torch_geometric.data.InMemoryDataset.processed_file_names()
和(1)類似,存放processed_dir
目錄下的文件名的列表,用於檢測是否已經存在(不會二次處理)。
(3)torch_geometric.data.InMemoryDataset.download()
下載數據到raw_dir
目錄下。
(4)torch_geometric.data.InMemoryDataset.process()
對raw_dir
下的數據進行處理並存儲到processed_dir
目錄下。
因此,可以發現關鍵在於第四個函數的實現,函數內首先需要讀取原始數據並創建一個torch_geometric.data.Data
對象的列表,並存儲到processed_dir
目錄下面。直接存儲和使用這個python-list
時間代價很高,所以在存儲之前調用torch_geometric.data.InMemoryDataset.collate()
函數將列表轉換爲一個torch_geometric.data.Data
對象。處理後的數據被整合到了一個數據對象中(作爲返回值),同時返回一個slices
字典來獲取到這個數據對象中單個數據,所以總結下來process
過程一共分四步:
- 加載數據創建列表
- 進行各種處理過程
- 調用collate()函數
- 存儲本地
最後在數據類的構造函數中加載數據集並賦值給self.data
和self.slices
。
import torch
from torch_geometric.data import InMemoryDataset
class MyDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
# 數據的下載和處理過程在父類中調用實現
super(MyDataset, self).__init__(root, transform, pre_transform)
# 加載數據
self.data, self.slices = torch.load(self.processed_paths[0])
# 將函數修飾爲類屬性
@property
def raw_file_names(self):
return ['file_1', 'file_2']
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# download to self.raw_dir
pass
def process(self):
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_filter is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
# 這裏的save方式以及路徑需要對應構造函數中的load操作
torch.save((data, slices), self.processed_paths[0])
3.創建大規模數據
大數據集一般不會直接加載到內存中,這裏構建數據集的時候需要繼承父類torch_geometric.data.Dataset
。在上面構建數據集時,重寫了四個函數,此處還需要多實現兩個函數:
(1)torch_geometric.data.Dataset.len()
返回數據集的文件個數。
(2)torch_geometric.data.Dataset.get()
實現對單個數據(圖數據集的話一般是單個圖)的加載邏輯。
import os.path as osp
import torch
# 這裏就不能用InMemoryDataset了
from torch_geometric.data import Dataset
class MyDataset(Dataset):
# 默認預處理函數的參數都是None
def __init__(self, root, transform=None, pre_transform=None):
super(MyDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['file_1', 'file_2']
@property
def processed_file_names(self):
# 一次無法加載所有數據,所以對數據進行了分解
return ['data1.pt', 'data2.pt', 'data3.pt']
def download(self):
# Download to raw_dir
pass
def process(self):
i = 0
# 遍歷每一個文件路徑
for raw_path in self.raw_paths:
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data{}.pt',format(idx)))
return data
4.相關細節
當我第一遍看完文檔之後,心中還是存在很多疑惑的,第一,畢竟直接繼承了一個父類,具體的流程是如何的,還不清楚,第二,沒有親自制作一個數據集,的確理解上存在模糊,下面對我個人的一些疑惑進行探索。
4.1具體的數據加載流程
這裏的流程是指包括了個人定義的數據類內部的邏輯以及父類InMemoryDataset
中的邏輯(先分析內存數據集):
1.對MyDataset
實例化,此時調用類內構造函數__init__
,先通過父類構造函數,再從本地加載數據,因此所有的關鍵操作都是在父類構造中發生的。
2.(在調用父類構造函數的時候,根據文檔的官方例子我產生了兩個疑惑,第一個是參數中沒有傳遞pre_filter
參數,但是後面爲什麼還要判斷self.pre_filter
,難道說默認的pre_filter
不是None
?而是父類中給了一個實現方式?第二個是參數中傳遞了transform
,但是在重寫的process
函數並沒有transform
的過程,那麼這個過程又是在哪裏實現的呢?)在InMemoryDataset
類中,構造函數爲:
def __init__(self, root=None, transform=None, pre_transform=None,
pre_filter=None):
super(InMemoryDataset, self).__init__(root, transform, pre_transform,
pre_filter)
self.data, self.slices = None, None
其中transform
、pre_transform
和pre_filter
都是函數句柄(callable),具體說明如下:
(1)transform
接受參數類型爲torch_geometric.data.Data
並且返回一個轉換後的版本(數據類型不變),在每一次數據加載到程序之前都會默認調用進行數據轉換。
(2)pre_transform
接收參數類型爲torch_geometric.data.Data
,返回轉換後的版本,在數據被存儲到硬盤之前進行轉換(只發生一次)。
(3)pre_filter
接受參數類型爲torch_geometric.data.Data
,返回布爾類型結果,相當於對原始數據的一個mask
。
可以看到InMemoryDataset
中構造函數的參數,這三個函數參數都是None
。這也就是解決了之前的第一個疑問,如果要用pre_filter
,就必須傳遞該參數,否則爲None
。
3.調用InMemoryDataset
的父類Dataset
的構造函數,其實此處就可以發現大部分的邏輯已經可以在Dataset
類中看到了。先對之前的疑惑二進行解答何時調用transform
,爲什麼在process
中沒有transform
呢?
def __getitem__(self, idx):
r"""Gets the data object at index :obj:`idx` and transforms it (in case
a :obj:`self.transform` is given).
In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
tuple, a LongTensor or a BoolTensor, will return a subset of the
dataset at the specified indices."""
if isinstance(idx, int):
data = self.get(self.indices()[idx])
data = data if self.transform is None else self.transform(data)
return data
else:
return self.index_select(idx)
這一段代碼是源碼Dataset
類中的函數,可以看到這個函數是根據索引獲取部分數據,idx
爲索引目標,可以是列表、元組、LongTensor或者BoolTensor。可以看到只有在訪問數據元素時,纔會調用transform
函數。
4.在Dataset的構造函數中,有這麼幾行代碼:
if 'download' in self.__class__.__dict__.keys():
self._download()
if 'process' in self.__class__.__dict__.keys():
self._process()
此處調用下載函數和處理函數,而self._download()
會調用self.download()
,process
同理。
5.將處理好的數據存儲到本地,然後再加載到程序中。
以上就是詳細的處理流程了,值得注意的是,如果需要下載數據,利用request相關技術,需要自己重寫download()
函數;如果要對數據進行預過濾、轉換和預轉換,需要定義外部函數作爲參數傳遞給構造過程。
4.2 實例學習
看了上面的內容,可能還是不知道咋做,現在就通過官方數據集的源碼進行一波分析。例子以Planetoid
爲例:
from torch_geometric.datasets import Planetoid
1.構造函數中transform
和pre_transform
都設置了None
,但是沒有pre_filter
參數,也就是說這裏不允許傳遞pre_filter
參數。
def __init__(self, root, name, transform=None, pre_transform=None):
self.name = name
super(Planetoid, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
該數據集只有一個數據文件,所以直接取索引0。
2. 下載函數如下:
def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw_dir)
遍歷每一個文件名,然後調用download_url
函數進行下載。
from torch_geometric.data import download_url
不過在download_url
和Dataset
類中的_download
函數中都進行防覆蓋檢測。
3.處理函數如下:
def process(self):
data = read_planetoid_data(self.raw_dir, self.name)
data = data if self.pre_transform is None else self.pre_transform(data)
torch.save(self.collate([data]), self.processed_paths[0])
第一步讀取數據,第二步轉換,第三步存儲,主要是第一步的操作,這裏調用了一個函數read_planetoid_data
,此函數讀取本地文件後,進行了訓練集、測試集、驗證集的劃分,並且構造了一個Data對象:
data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask
在存儲之前調用了
self.collate([data])
該函數的具體內容在下一小節中講解。
4.3 collate函數
collate
函數在InMemoryDataset
中實現,將一個python列表形式數據轉換(每一個元素都是一個數據對象)爲torch_geometric.data.InMemoryDataset
內部存儲數據的格式。這裏每一個數據對象未必是Data類型(一般代表一個Graph),也可以是其他的,比如圖片等。
data = data_list[0].__class__()
這一行代碼可以對第一個元素的類名解析並重新構造一個同類型元素。
for item, key in product(data_list, keys):
data[key].append(item[key])
利用笛卡爾積構造元組替代雙層循環,並且將列表中所有數據元素的值存放到一個數據對象中。後面的代碼進行了一些拼接過程,具體的見Github。