【PyG學習入門】二:入門時遇到的問題

【PyG學習入門】一:入門使用

在上一篇的介紹中,主要講了Pytorch-Geometric的五個基礎用例,但是其中存在一些問題還沒有解決,下面開始一一解決,本文的重點是如何手動加載PyG的數據集。

1.關於創建Data實例時,維度異常的問題

問題描述:
Data創建過程中,edge_index表示邊的信息,x爲節點的特徵向量,y爲目標值,如果y的維度([num_nodes, *])和節點總數的維度是一樣的,那就是node-level;如果y的維度是[1,*],那就是graph-level。但是如果y的維度不符合上述的兩種情況,在創建過程會如何?
解決方案:

### Q1: X維度和Y的維度不統一
import torch
from torch_geometric.data import Data

# 構建邊
edge_index = torch.tensor([
    [3, 1, 1, 2],
    [1, 3, 2, 1]], dtype=torch.long)
# 構建X
x = torch.tensor([[-1],
                  [0],
                  [1],[2]], dtype=torch.float)
y = torch.tensor([[1], [2], [3]], dtype=torch.float)
data = Data(x=x, y=y, edge_index=edge_index)

print(data)

如代碼所示,其中節點共有3個,但是我創建了4個節點的特徵向量和5個目標值,運行代碼後沒有出現錯誤,所以可以得知Data實例化的過程中,是不會檢查數據是否合理的,只是單純的構建了一個複雜數據類型而已。

2.如何加載自己下載的數據集

問題描述:
在使用Dataset進行數據集的創建時,經常會出現HttpError這種樣子的錯誤,所以手動下載數據集之後,再利用PyG的函數進行構建,但是這個方式目前還沒有找到官方的接口,所以要從源碼的角度來處理。
解決方案:
這裏用Cora數據集進行實驗,在planetoid.py文件中可以看到代碼的下載地址爲:

url = 'https://github.com/kimiyoung/planetoid/raw/master/data'

該文件內定義了一個類:

class Planetoid(InMemoryDataset):

裏面有有一個download函數用於下載數據集:

def download(self):
    for name in self.raw_file_names:
        download_url('{}/{}'.format(self.url, name), self.raw_dir)

調用函數download_url指定下載地址和下載後存放的目錄,其中下載的列表爲:

@property
def raw_file_names(self):
    names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
    return ['ind.{}.{}'.format(self.name.lower(), name) for name in names]

其中一共有八個文件。除了download函數還有一個process函數:

def process(self):
    data = read_planetoid_data(self.raw_dir, self.name)
    data = data if self.pre_transform is None else self.pre_transform(data)
    torch.save(self.collate([data]), self.processed_paths[0])

第一行代碼中的read_planetoid_data函數是進行數據的加載並且切分出訓練集、驗證集、測試集,構造爲Data實例:

data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

第二行代碼表示是否進行pre_transform操作(也就是是否進行一次數據轉換,一般3D點雲數據比較常見);第三行利用torch.save進行本地序列化。經過上面的一段分析,可以確定的是首先進行數據集下載,然後進行處理最後保存到本地一個新的序列化文件,所以只需要在下載過程跳過即可,但是考慮到這麼一點,在之前學習的過程中,可以發現,當你第一次創建完數據集(下載到本地)之後,第二次時間比較短,所以一定存在防覆蓋機制來優化程序運行速度,於是在download_url函數中找到了這一塊代碼:

if osp.exists(path):  # pragma: no cover
    if log:
        print('Using exist file', filename)
    return path

所以總結一下:
(1)根據URL下載自己的數據集;
(2)放到本地文件夾中,格式爲:
在這裏插入圖片描述
其中Cora文件夾是根目錄,processed是處理後torch.save的,所以自己下載的數據放在raw文件夾中;
(3)調用Dataset中的接口創建數據集即可。

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data/', name='Cora')
print(dataset)

輸出信息爲:

Processing...
Done!
Cora()

之前GCNModelacc爲:

Accuracy: 0.8080
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章