Graph Neural Networks 17 - DGL in Practice: Building Graph Neural Network (GNN) Modules

1 The constructor of a DGL NN module

The constructor performs the following tasks:

  1. Set options.
  2. Register learnable parameters or submodules.
  3. Initialize parameters.

    import torch.nn as nn

    from dgl.utils import expand_as_pair

    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()

            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation

In the constructor, the user first needs to set the data dimensions. For a typical PyTorch module, these usually include the input dimension, the output dimension, and the hidden dimension.
For graph neural networks, the input dimension is split into a source-node feature dimension and a destination-node feature dimension.
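
As a quick illustration of this split, expand_as_pair (imported above, and shown in detail later in this guide) accepts either a single dimension or a (source, destination) pair. A minimal sketch:

    from dgl.utils import expand_as_pair

    # A single value is reused for both source and destination nodes
    print(expand_as_pair(64))        # (64, 64)
    # A tuple keeps distinct source and destination dimensions
    print(expand_as_pair((64, 32)))  # (64, 32)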

Besides the data dimensions, a typical option for a graph neural network module is the aggregation type (self._aggre_type). For a given destination node, the aggregation type determines how the information arriving along different edges is aggregated.
Commonly used aggregation types include mean, sum, max, and min. Some modules may use more sophisticated aggregation functions such as lstm.

The norm above is a callable used for feature normalization. In the GraphSAGE paper, this normalization can be the L2 normalization:
h_v = h_v / \lVert h_v \rVert_2
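
Such a norm callable is not built into DGL; a minimal sketch of one (the eps clamp is an assumption added here to avoid division by zero) could be:

    import torch

    def l2_norm(h, eps=1e-12):
        # Row-wise L2 normalization: h_v = h_v / ||h_v||_2
        return h / h.norm(dim=-1, keepdim=True).clamp(min=eps)

    # Hypothetical usage: SAGEConv(64, 32, 'mean', norm=l2_norm)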

    # Aggregator types: mean, max_pool, lstm, gcn
    if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
        raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
    if aggregator_type == 'max_pool':
        self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
    if aggregator_type == 'lstm':
        self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
    if aggregator_type in ['mean', 'max_pool', 'lstm']:
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
    # fc_neigh is needed by every aggregator type (including 'gcn'),
    # and the weights are always (re)initialized at the end
    self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
    self.reset_parameters()

Next, the constructor registers parameters and submodules. In SAGEConv, the submodules vary with the aggregation type. These modules are plain PyTorch NN modules such as nn.Linear and nn.LSTM.
At the end, the constructor calls reset_parameters() to initialize the weights.

    def reset_parameters(self):
        """Reinitialize the learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
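
Putting the pieces together, constructing the module registers the aggregator-specific submodules and initializes their weights. A quick check, assuming the SAGEConv class as defined above:

    conv = SAGEConv(in_feats=64, out_feats=32, aggregator_type='max_pool')
    print(conv.fc_pool)   # Linear(in_features=64, out_features=64, bias=True): 'max_pool' only
    print(conv.fc_neigh)  # Linear(in_features=64, out_features=32, bias=True): used by all types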

2 Writing the forward function of a DGL NN module

In an NN module, the forward() function performs the actual message passing and computation. Compared with a typical PyTorch NN module, which takes tensors as arguments,
a DGL NN module takes one additional argument: a dgl.DGLGraph. The work done in a forward() function can generally be divided into three parts:

  • Validate the input graph object.

  • Perform message passing and aggregation.

  • Update the aggregated features to produce the output.

The forward() function of the SAGEConv example is shown below.

Validating the input graph object

    def forward(self, graph, feat):
        with graph.local_scope():
            # Determine the graph type, then expand the input feature accordingly
            feat_src, feat_dst = expand_as_pair(feat, graph)
            # Keep the destination nodes' own features for the later combination step
            h_self = feat_dst

The forward() function needs to handle many corner cases in the input that can lead to invalid values during computation and message passing.
One typical check in conv modules such as GraphConv is to verify that the input graph has no 0-in-degree nodes.
When a node has an in-degree of 0, its mailbox is empty and the output of the aggregation function is all zeros,
which may silently hurt model performance. In the SAGEConv module, however, the aggregated features are concatenated with each node's original features,
so the output of forward() will not be all zeros; in this case no such check is needed.
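
Such a 0-in-degree check can be sketched as follows (a hypothetical helper written for illustration; recent versions of GraphConv raise a DGLError in this situation instead):

    import dgl
    import torch

    def has_zero_in_degree_nodes(graph):
        # Nodes with in-degree 0 receive no messages, so aggregation yields zeros
        return bool((graph.in_degrees() == 0).any())

    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=4)
    print(has_zero_in_degree_nodes(g))  # True: nodes 0 and 3 have in-degree 0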

A DGL NN module should be reusable across different types of graph input, including homogeneous graphs, heterogeneous graphs, and subgraph blocks (used in mini-batch training).

The mathematical formulas of SAGEConv are:

h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}\left(\{h_{src}^{(l)}, \forall src \in \mathcal{N}(dst)\}\right)

h_{dst}^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}(h_{dst}^{(l)}, h_{\mathcal{N}(dst)}^{(l+1)})\right)

h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{(l+1)})

The source-node features feat_src and the destination-node features feat_dst need to be determined according to the graph type.
The function that determines the graph type and expands feat into feat_src and feat_dst is expand_as_pair().
Its details are shown below.


    def expand_as_pair(input_, g=None):
        if isinstance(input_, tuple):
            # Bipartite graph case
            return input_
        elif g is not None and g.is_block:
            # Subgraph block case
            if isinstance(input_, Mapping):
                input_dst = {
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                    for k, v in input_.items()}
            else:
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
            return input_, input_dst
        else:
            # Homogeneous graph case
            return input_, input_

For full-graph training on a homogeneous graph, source nodes and destination nodes are the same: all the nodes in the graph.

In the heterogeneous case, the graph can be decomposed into several bipartite graphs, one for each relation, where a relation is represented as (src_type, edge_type, dst_type).
When the input feature feat is a tuple, the graph is treated as a bipartite graph: the first element of the tuple holds the source-node features and the second element holds the destination-node features.
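
For example, applying the built-in dgl.nn SAGEConv to a single-relation graph with a tuple of features might look like the following sketch (the graph and feature sizes are made up for illustration):

    import dgl
    import torch
    from dgl.nn.pytorch import SAGEConv

    # A single-relation (bipartite) graph: 2 'user' nodes and 1 'game' node
    g = dgl.heterograph({('user', 'plays', 'game'):
                         (torch.tensor([0, 1]), torch.tensor([0, 0]))})
    u_feat = torch.randn(2, 5)   # source ('user') features
    v_feat = torch.randn(1, 10)  # destination ('game') features
    conv = SAGEConv((5, 10), 4, 'mean')
    rst = conv(g, (u_feat, v_feat))
    print(rst.shape)  # torch.Size([1, 4]): one row per destination node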

In mini-batch training, the computation runs on a subgraph sampled for a given batch of destination nodes. Such a subgraph is called a block in DGL.
During block creation, dst nodes are placed at the front of the node list, so feat_dst can be found by taking the slice [0:g.number_of_dst_nodes()].

Once feat_src and feat_dst have been determined, the computation is identical for all three graph types.
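
The block case can be verified with a small sketch (a toy graph of our own; dgl.to_block places the destination nodes at the front of the node list):

    import dgl
    import torch

    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
    sg = dgl.in_subgraph(g, torch.tensor([2, 3]))       # edges into the targets
    block = dgl.to_block(sg, dst_nodes=torch.tensor([2, 3]))
    feat = torch.randn(block.number_of_src_nodes(), 8)  # one row per src node
    feat_dst = feat[0:block.number_of_dst_nodes()]      # dst nodes come first
    print(block.is_block, feat_dst.shape)               # True torch.Size([2, 8])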

Message passing and aggregation

    import dgl.function as fn
    import torch.nn.functional as F
    from dgl.utils import check_eq_shape

    if self._aggre_type == 'mean':
        graph.srcdata['h'] = feat_src
        graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
    elif self._aggre_type == 'gcn':
        check_eq_shape(feat)
        graph.srcdata['h'] = feat_src
        graph.dstdata['h'] = feat_dst
        graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
        # Divide by the in-degree (+1 for the node itself)
        degs = graph.in_degrees().to(feat_dst)
        h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
    elif self._aggre_type == 'max_pool':
        graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
        graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
    else:
        # (The 'lstm' branch is omitted here for brevity)
        raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

    # The gcn aggregator of GraphSAGE does not require fc_self
    if self._aggre_type == 'gcn':
        rst = self.fc_neigh(h_neigh)
    else:
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

The code above performs the message passing and aggregation computation. This part differs from module to module.
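
The same pattern can be tried standalone on a toy graph (a minimal sketch with arbitrary feature sizes):

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([2, 2, 0])))
    g.ndata['h'] = torch.randn(3, 4)
    # Copy source features onto edges, then mean-aggregate at destinations
    g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    print(g.ndata['neigh'].shape)  # torch.Size([3, 4]); node 1 has in-degree 0,
                                   # so its row is all zeros (see the check above)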

Updating the features as output after aggregation

    # Activation function
    if self.activation is not None:
        rst = self.activation(rst)
    # Normalization
    if self.norm is not None:
        rst = self.norm(rst)
    return rst

The last part of the forward() function updates the node features after message aggregation is complete.
Common update operations apply the activation function and the normalization according to the options set in the constructor.

3 A simple graph classification task

In this tutorial, we will learn how to perform graph classification with DGL. The goal of this example is to classify the eight types of graph topologies shown below.

[Figure: sample graphs of the eight topology classes]

Here we directly use the synthetic dataset data.MiniGCDataset from DGL. The dataset contains eight different types of graphs, and each class has the same number of graph samples.

from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, where each
# graph has between 10 and 20 nodes
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()
Using backend: pytorch

Creating batches of graphs

import dgl
import torch

def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels, dtype=torch.long)
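
A quick sanity check of the collate function (sizes are arbitrary; dgl.batch merges the input graphs into one large graph with disconnected components):

    dataset = MiniGCDataset(8, 10, 20)
    bg, labels = collate([dataset[i] for i in range(4)])
    # batch_size is 4; num_nodes() is the total over the 4 graphs; labels has shape [4]
    print(bg.batch_size, bg.num_nodes(), labels.shape)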

Building the graph classifier

from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs,
        # the in-degree is the same as the out-degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
Epoch 0, loss 2.0010
Epoch 1, loss 1.9744
Epoch 2, loss 1.9551
...
Epoch 77, loss 0.6414
Epoch 78, loss 0.6442
Epoch 79, loss 0.6398
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))
Accuracy of sampled predictions on the test set: 58.7500%
Accuracy of argmax predictions on the test set: 62.5000%