[PyG] 1.如何使用GCN完成一個最基本的訓練過程(含GCN實現)

0. 前言

爲啥要學習Pytorch-Geometric呢?(下文統一簡稱爲PyG) 簡單來說,是目前做的項目有用到,還有1個特點,就是相比NYU的DeepGraphLibrary, DGL的問題是API比較棘手,而且目前沒有遷移的必要性。

圖卷積框架能做的事情比較多,提供了很多方便的數據集和各種GNN SOTA的實現,其實最吸引我的就是這個framework的API比較友好,再加之使用PyG做項目的人比較多,生態對我這種做3D mesh的人比較友好。

注意, 本教程完全基於官方最新 (2020.04.14) 的教程,在此基礎上,完成了簡化版本的GCN的實現,對GCN的官方實現感興趣的童鞋可以康康[1]

下面,我將完全按照[1]的步驟來,不同之處在於,我在這裏將基於PyG的最新版本(1.4.3)來分析GCN的簡化版實現,讓大家更加理解GCN的實現原理, 以下是闡述順序:

  • ①圖數據的Data Handling

  • ②Common Benchmark Datasets

  • ③Mini-batches

  • ④Data Transforms

  • ⑤Learning Methods on Graphs

此外,我所使用的環境是:

  • Ubuntu 18.04
  • Cuda10.0
  • pytorch 1.4.0 conda install pytorch=1.4.0 cudatoolkit=10.0
  • pytorch geometric 1.4.3
  • torch-scatter pip install torch-scatter==latest+cu100 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
  • torch-spline-conv pip install torch-spline-conv==latest+cu100 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
  • torch-cluster pip install torch-cluster==latest+cu100 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
  • torch-sparse pip install torch-sparse==latest+cu100 -f https://pytorch-geometric.com/whl/torch-1.4.0.html

1. 圖結構的數據處理

首先,是什麼?圖是邊和點的相關關係的組合。在PyG中,一個簡單的graph可以被描述爲torch_geometric.data.Data[2]的實例,其中有幾個重要的屬性需要說明,此外,如果你的圖需要擴展,那麼你可以對torch_geometric.data.Data這個類進行修改即可。

在這裏插入圖片描述
圖1.1 torch_geometric.data.Data的常用成員變量

通常來講,對分一般的任務來說,數據類只需要有x,edge_index,edge_attr,y等幾個屬性即可,而且,這些屬性都是optional(可選)的,也就是說,Data類並不侷限於這些屬性。

舉個栗子,可以擴展data.face(torch.LongTensor, [3, num_faces])來保存3D mesh的三角形的連接關係.

在這裏插入圖片描述
圖1.2 torch_geometric.data.Data的官方說明

在這裏插入圖片描述
圖1.3 Data實例(3個節點,4條邊(雙向), 每個節點有2個特徵[-1, 2], [0, 3], [1, 1].)

需要注意的是,儘管圖只有2條邊,我們還是需要定義4個index tuple來考慮邊的雙向關係。
圖1.3搭建的graph的示意圖如下:
在這裏插入圖片描述

2. 常見Benchmark數據集

儘管最近Bengio團隊是基於DGL開發的6個Benchmark數據集,但是在pyG上做這個也沒問題呀~。所以也不必直接因此就轉去DGL。

PyTorch Geometric包含了大量的基礎數據集, 所有的Planetoid datasets (Cora, Citeseer, Pubmed), 來自多特蒙德工大的清洗過的圖分類數據集, 一系列3D點雲和mesh的數據集,比如FAUST,ShapeNet等。

PyG提供了這些數據的自動下載,並將其處理成之前說的Data形式,以ENZYMES數據集爲例(包含600個圖和6個類別):
在這裏插入圖片描述
圖2.1 ENZYMES數據集的解析

由圖2.1可見,其中的每個樣本都是Data的instance,有頂點特徵x,連接關係edge_index以及類別y 3個屬性. 可以看出,ENZYMES的每個數據都是1個圖。

注意: 可以通過使用dataset=dataset.shuffle()來對數據集進行shuffle。

此外,教程上還提供了Planetoid的Cora數據集的說明(用於semi-supervised graph node classification), 這裏Cora數據集的數據有3個新的屬性train_mask, test_mask, val_mask, 這3個屬性用於表徵需要訓練、測試和驗證的數據節點。

Cora與ENZYMES的區別是,Cora中的每個數據是整個圖中的1個節點,而ENZYMES的每個數據都是1個獨立的圖。

在這裏插入圖片描述
圖2.2 Cora數據集說明

3. Mini-Batches

我們知道,神經網絡通常是按Batch訓練的,PyG通過創建稀疏的鄰接矩陣(sparse block diagnol adjacency matrices)實現在mini-batch上的並行化。

在這裏插入圖片描述
圖3.1 PyG mini-batch對不同的節點、邊數量的圖的批處理

並按照node dimension來拼接節點特徵x和類別特徵y。通過這種方式,PyG可以在一個Batch中塞進不同nodes和edges數的樣本。
在這裏插入圖片描述
圖3.2 ENZYMES數據集加載說明(未shuffle)

(注意,這裏的DataLoader用的是PyG自己的,而不是pytorch的,此外,use_node_attr=False時, x爲[nodes_num, 3]; use_node_attr=True時, x爲[nodes_num, 21])

這裏,torch_geometric.data.Batch繼承自 torch_geometric.data.Data,多了一個名爲batch的屬性,其作用是標示每個節點屬於哪個圖(ENZYMES)/樣本.

此外,torch_geometric.data.DataLoader也只是pytorch的Dataloader重寫了collate函數的版本而已。

正常傳遞給pytorch的Dataloader的參數,如pin_memory,num_workers等都可以傳給torch_geometric.data.DataLoader.

當然,用戶可以通過使用torch-scatter[3]對節點數據特徵x進行自定義的處理並使用自定義的Dataset和Dataloader來處理自己的特殊形式數據[4].

4. Data Transforms

torchvision在pytorch中的使用類似,我們也需要對graph數據進行處理和變換。PyG提供了自己的transform方式和工具包,要求的輸入爲Data對象,並返回transformed的Data對象。

類似地,transform可以通過torch_geometric.transforms.Compose來進行一系列的拼接。

作者舉得例子是ShapeNet數據集(包含17,000 3D shape point clouds and per point labels from 16 shape categories)的Airplane類,作者通過pre_transform = T.KNNGraph(k=6)將point cloud數據變爲graph數據集。

在這裏插入圖片描述
圖4.1 ShapeNet數據集處理(將點雲數據變爲graph數據)

如有其它需要,用戶可以自己去torch_geometric.transforms進行查閱是否有符合自己目的的transform,沒有的話自己寫~

5. Learning Methods on Graphs

在搞定前4步後,現在讓我們開始搞起第1個GNN~,這裏,我們將會使用最基礎的GCN層來複現Cora Citation數據集上的實驗,若要理解GCN,需要從Fourier變換講起,類比time domain --> frequency domain, 經過Hemlholtz公式,將vertex domain變到 spectral domain來分析,這樣一來,vertex domain的卷積就變成了spectral domain的點乘,節省了計算量。

此外,變換的過程中, 還涉及到Laplacian矩陣L的意義(每個vertex的散度Divergence:可以理解爲每個vertex的信息的增益情況出射爲正,入射爲負),因爲L的性質(半正定,特徵值大於等於0等),假設其特徵值爲λλ,特徵向量爲UU,通過與頻譜圖對比:

  • UU就可以類比爲Fourier變換的basis函數;
  • λλ就類比爲頻率w

GCN[5]就是在此基礎上,經由2步優化得到的,它既考慮了self-loop,也考慮了k-localize(局部性),還對度進行了renormalization,避免馬太效應過於明顯,使得模型不會很容易陷入local minima。

好了,就不再多提了,對GCN的推導和出現感興趣的,可以看[6-7](先理解Laplacian矩陣和變換在圖論中的一般含義, 再去油管上看臺灣大學姜成翰助教關於GNN的教程)進行學習,下面我們看代碼。

5.1 GCN在PyG的實現

PyG提供了torch_geometric.nn.MessagePassing這個base class,通過繼承這個類,我們可以實現各種基於消息傳遞的GNN,藉由MessagePassing
用戶不再需要關注message的progation的內部實現細節,MessagePassing主要關注其UPDATE, AGGREGATION, MESSAGE 這3個成員函數。

用戶在實現自己的GNN時,一般只overwrite AGGREGATION, UPDATE這2個成員函數,MESSAGE/Propagate用MessagePassing自帶的。(官方的GCN就是這樣的~)

我們的目標是: 實現1個與官方一致的簡化版的GCN,並通過實現它來掌握如何在PyG中定義圖卷積。

  • 首先,我們先定義一個圖數據data(有向圖, 4個節點,3條邊, 每個節點的特徵維度都是1, 值也都爲1):
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]], dtype=torch.long)
x = torch.tensor([[1], [1], [1], [1]], dtype=torch.float)

data = Data(edge_index=edge_index, x=x)
print(edge_index)
print(data)

在這裏插入圖片描述

  • MessagePassing消息傳遞機制

在這裏插入圖片描述
通過上面這個圖[9],很容易瞭解到基於消息傳遞的GCN的每1塊對應的內容是什麼:
message = ϕ\phi; aggregation = ; Update = γ\gamma. 那麼替換上面的消息傳遞公式,得到如下的新形式:
在這裏插入圖片描述
在這裏插入圖片描述
此圖就表示了本例的數據的流轉方式,需要注意: GCN默認的scatter方式是add(至於爲啥,請看下圖: 因爲用meanmax的情況下,每個subgraph隨着GNN網絡層數的加大,其中各個node之前的特徵區分度越來越小,這不符合我們的目標)
在這裏插入圖片描述

  • GCN的實現也可以分爲5步:
    在這裏插入圖片描述
  1. Add self-loops to the adjacency matrix. edge_index = A^=IN+A\hat A = I_{N} + A (代碼裏通過修改edge_index實現), PyG源代碼中通過add_remaining_self_loops函數來實現[10]

edge_index
在這裏插入圖片描述

加self_loop後的edge_index
在這裏插入圖片描述

  1. Linearly transform node feature matrix. x = ΘW\Theta W (代碼裏對應self.matmul(weight, x))
    原輸入x
    在這裏插入圖片描述
    經過weight transform得到的x
    在這裏插入圖片描述

  2. Normalize node features. norm = D^0.5A^D^0.5\hat D^{-0.5}\hat A \hat D^{-0.5} (在源碼中A^\hat A用以表示邊的權重,默認情況下都是1.)

norm的值,其長度同edge_index的一致:
在這裏插入圖片描述

  1. Sum up neighboring node features. iN(p)(D^0.5A^D^0.5ΘW)\sum_{i ∈ N(p)}(\hat D^{-0.5}\hat A \hat D^{-0.5} \Theta W)
    (第4步在MessagePassing裏面實現,即上面的圖中的scatter_add/sum/mean, 用戶無需操心) 在def message(self, x_j, norm)中的x_j就是第2步x的擴展到self_loop的結果.

x_j的值:
在這裏插入圖片描述
message(self, x_j, norm)的輸出:
在這裏插入圖片描述

  1. Return new node embeddings. 返回得到的結果XnewX_{new}.

因爲是scatter_add的方式,所以將[1, 0], [2, 0], [3, 0]的連接關係相加,得到最終輸出結果:

顯然
0.0330=0.00750.00750.00750.0106-0.0330 = -0.0075 -0.0075 -0.0075-0.0106
2.3680=0.5364+0.5364+0.5364+0.75862.3680 = 0.5364+0.5364+0.5364+0.7586
在這裏插入圖片描述
同樣,若改成scatter_max的話,結果爲如下,因爲0.0075=max(0.0075,0.0106)-0.0075 = max(-0.0075, -0.0106), 0.7586=max(0.5364,0.7586)0.7586 = max(0.5364, 0.7586)
在這裏插入圖片描述

這五步的實現通過如下代碼完整實現:

import torch
from torch_scatter import scatter_add
from torch_geometric.nn import MessagePassing
import math

def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

        
def add_self_loops(edge_index, num_nodes=None):
    print("進入self_loops")
    loop_index = torch.arange(0, num_nodes, dtype=torch.long,
                              device=edge_index.device)
    print(loop_index)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    print(loop_index)
    
    edge_index = torch.cat([edge_index, loop_index], dim=1)
    print(edge_index)
    print("出self_loops")
	# 原來的edge_index爲[[1, 2, 3],
	#                   [0, 0, 0]]
    #  這樣一來,就在原來的邊連接關係edge_index的基礎上增加了self_loop的關係.
    #  torch.cat([edge_index, loop_index], dim=1)
    #      tensor([[1, 2, 3, 0, 1, 2, 3],
    #              [0, 0, 0, 0, 1, 2, 3]])

    
    return edge_index


def degree(index, num_nodes=None, dtype=None):
    out = torch.zeros((num_nodes), dtype=dtype, device=index.device)
    print(out.scatter_add_(0, index, out.new_ones((index.size(0)))))
    return out.scatter_add_(0, index, out.new_ones((index.size(0))))
        

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels, bias=True):
    
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        # super(GCNConv, self).__init__(aggr='max')  # "Max" aggregation.
        
        self.weight = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
        
    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        # Step 1: 爲adjacency matrix添加self_loop(通過對edge_index拼接連向自己的邊[1, 1], [2, 2]等)
        # 原來的edge_index = tensor([[1, 2, 3],
        #                           [0, 0, 0]])
        # 加上self_loop的index = tensor([[1, 2, 3, 0, 1, 2, 3],
        #                               [0, 0, 0, 0, 1, 2, 3]])
        edge_index = add_self_loops(edge_index, x.size(0))

        # Step 2: 對輸入的node feature matrix進行weight transform.
        x = torch.matmul(x, self.weight)

        # Step 3-5: 開始消息傳遞.
        edge_weight = torch.ones((edge_index.size(1),), 
                                  dtype=x.dtype,
                                  device=edge_index.device)
        row, col = edge_index
        print("row", row)  # row tensor([1, 2, 3, 0, 1, 2, 3])
        print("col", col)  # col tensor([0, 0, 0, 0, 1, 2, 3])
        deg = scatter_add(edge_weight, row, dim=0, dim_size=x.size(0))
        print("deg", deg)  
        # deg是[1, 2, 2, 2], 這是啥?
        # 因爲
        # row = [1, 2, 3, 0, 1, 2, 3]
        # edge_weight = [1, 1, 1, 1, 1, 1, 1]
        # 所以,主對角上,第0個對應1,第1個對應2個,同理,得到degree矩陣. 這裏只返回主對角的元素, 避免稀疏乘.
        
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        
        # 讀edge_weight爲None的情況, 
        # deg_inv_sqrt[row] * edge_weight *  deg_inv_sqrt[col] == deg_inv_sqrt[row] *  deg_inv_sqrt[col]
        norm = deg_inv_sqrt[row] * edge_weight *  deg_inv_sqrt[col]
        print(norm)
        # norm = tensor([0.7071, 0.7071, 0.7071, 1.0000, 0.5000, 0.5000, 0.5000])
        
        return self.propagate(edge_index, x=x, norm=norm)           


    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # norm: 規則化後的權重.
        return norm.view(-1, 1) * x_j if norm is not None else x_j                  
        


    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]

        # Step 5: 返回新的node embeddings.
        if self.bias is not None:
            return aggr_out + self.bias
        else:
            return aggr_out

進行實驗,得到與官方實現一樣的效果:
在這裏插入圖片描述

5.2 在Cora Citation數據集上進行訓練

這裏用回官方的GCN來做1個2層的GNN網絡對Cora Citation數據集進行訓練,如果一切ok,下面代碼直接複製到你的本地就可以跑起來~

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# 5.1) 加載Cora數據集.(自動幫你下載)
dataset = Planetoid(root='/home/pyG/Cora', name='Cora')

# 5.2) 定義2層GCN的網絡.
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)
        
# 5.3) 訓練 & 測試.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
# >>> Accuracy: 0.8150

到這步,一個完整的基於GCN的GNN就搞定了,至於訓練的數據處理和很多細節,需要大家hack源碼啦,祝大家學習愉快~

參考資料

[1] PyG官方Tutorial
[2] torch_geometric.data.Data
[3] torch-scatter
[4] advanced mini-batching of PyG
[5] GCN: Semi-supervised Classfication with Graph Convolutional Networks
[6] [其實賊簡單] 拉普拉斯算子和拉普拉斯矩
[7] GNN介紹 臺灣大學 姜成翰
[8] Torch geometric GCNConv 源碼分析
[9] MessagePassing
[10] add_remaining_self_loops

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章