在上一篇的介紹中,主要講了Pytorch-Geometric
的五個基礎用例,但是其中存在一些問題還沒有解決,下面開始一一解決,本文的重點是如何手動加載PyG的數據集。
1.關於創建Data實例時,維度異常的問題
問題描述:
在Data
創建過程中,edge_index
表示邊的信息,x
爲節點的特徵向量,y
爲目標值,如果y
的維度([num_nodes, *]
)和節點總數的維度是一樣的,那就是node-level
;如果y
的維度是[1,*]
,那就是graph-level
。但是如果y
的維度不符合上述的兩種情況,在創建過程會如何?
解決方案:
### Q1: X維度和Y的維度不統一
import torch
from torch_geometric.data import Data
# 構建邊
edge_index = torch.tensor([
[3, 1, 1, 2],
[1, 3, 2, 1]], dtype=torch.long)
# 構建X
x = torch.tensor([[-1],
[0],
[1],[2]], dtype=torch.float)
y = torch.tensor([[1], [2], [3]], dtype=torch.float)
data = Data(x=x, y=y, edge_index=edge_index)
print(data)
如代碼所示,其中節點共有3個,但是我創建了4個節點的特徵向量和5個目標值,運行代碼後沒有出現錯誤,所以可以得知Data
實例化的過程中,是不會檢查數據是否合理的,只是單純的構建了一個複雜數據類型而已。
2.如何加載自己下載的數據集
問題描述:
在使用Dataset
進行數據集的創建時,經常會出現HttpError
這種樣子的錯誤,所以手動下載數據集之後,再利用PyG
的函數進行構建,但是這個方式目前還沒有找到官方的接口,所以要從源碼的角度來處理。
解決方案:
這裏用Cora數據集進行實驗,在planetoid.py
文件中可以看到代碼的下載地址爲:
url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
該文件內定義了一個類:
class Planetoid(InMemoryDataset):
裏面有有一個download
函數用於下載數據集:
def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw_dir)
調用函數download_url
指定下載地址和下載後存放的目錄,其中下載的列表爲:
@property
def raw_file_names(self):
names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
return ['ind.{}.{}'.format(self.name.lower(), name) for name in names]
其中一共有八個文件。除了download
函數還有一個process
函數:
def process(self):
data = read_planetoid_data(self.raw_dir, self.name)
data = data if self.pre_transform is None else self.pre_transform(data)
torch.save(self.collate([data]), self.processed_paths[0])
第一行代碼中的read_planetoid_data
函數是進行數據的加載並且切分出訓練集、驗證集、測試集,構造爲Data實例:
data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask
第二行代碼表示是否進行pre_transform
操作(也就是是否進行一次數據轉換,一般3D點雲數據比較常見);第三行利用torch.save
進行本地序列化。經過上面的一段分析,可以確定的是首先進行數據集下載,然後進行處理最後保存到本地一個新的序列化文件,所以只需要在下載過程跳過即可,但是考慮到這麼一點,在之前學習的過程中,可以發現,當你第一次創建完數據集(下載到本地)之後,第二次時間比較短,所以一定存在防覆蓋機制
來優化程序運行速度,於是在download_url
函數中找到了這一塊代碼:
if osp.exists(path): # pragma: no cover
if log:
print('Using exist file', filename)
return path
所以總結一下:
(1)根據URL下載自己的數據集;
(2)放到本地文件夾中,格式爲:
其中Cora
文件夾是根目錄,processed
是處理後torch.save
的,所以自己下載的數據放在raw文件夾中;
(3)調用Dataset
中的接口創建數據集即可。
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/', name='Cora')
print(dataset)
輸出信息爲:
Processing...
Done!
Cora()
之前GCNModel
的acc
爲:
Accuracy: 0.8080