图神经网络17-DGL实战:构建图神经网络(GNN)模块

1 DGL NN模块的构造函数

构造函数完成以下几个任务:

  1. 设置选项。
  2. 注册可学习的参数或者子模块。
  3. 初始化参数。

    import torch.nn as nn

    from dgl.utils import expand_as_pair

    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()

            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation

在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。
对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。
常用的聚合类型包括 meansummaxmin。一些模块可能会使用更加复杂的聚合函数,比如 lstm

上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:
h_v = h_v / \lVert h_v \rVert_2

 # 聚合类型:mean、max_pool、lstm、gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
     raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
 if aggregator_type == 'max_pool':
     self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
      self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
      self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
      self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
      self.reset_parameters()

注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linearnn.LSTM 等。
构造函数的最后调用了 reset_parameters() 进行权重初始化。

 def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

2 编写DGL NN模块的forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,
DGL NN模块额外增加了1个参数 :class:dgl.DGLGraphforward() 函数的内容一般可以分为3项操作:

  • 检测输入图对象是否符合规范。

  • 消息传递和聚合。

  • 聚合后,更新特征作为输出。

下文展示了SAGEConv示例中的 forward() 函数。

输入图对象的规范检测

def forward(self, graph, feat):
        with graph.local_scope():
         # 指定图类型,然后根据图类型扩展输入特征
         feat_src, feat_dst = expand_as_pair(feat, graph)

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。
比如在 :class:~dgl.nn.pytorch.conv.GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。
当1个节点入度为0时, mailbox 将为空,并且聚合函数的输出值全为0,
这可能会导致模型性能不佳。但是,在 :class:~dgl.nn.pytorch.conv.SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来,
forward() 函数的输出不会全为0。在这种情况下,无需进行此类检验。

DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(:ref:guide_cn-graph-heterogeneous)和子图块(:ref:guide_cn-minibatch)。

SAGEConv的数学公式如下:

源节点特征 feat_src 和目标节点特征 feat_dst 需要根据图类型被指定。
用于指定图类型并将 feat 扩展为 feat_srcfeat_dst 的函数是 :meth:~dgl.utils.expand_as_pair
该函数的细节如下所示。


    def expand_as_pair(input_, g=None):
        if isinstance(input_, tuple):
            # 二分图的情况
            return input_
        elif g is not None and g.is_block:
            # 子图块的情况
            if isinstance(input_, Mapping):
                input_dst = {
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                    for k, v in input_.items()}
            else:
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
            return input_, input_dst
        else:
            # 同构图的情况
            return input_, input_

对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。

在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)
当输入特征 feat 是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。

在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block)。
在区块创建的阶段,dst nodes 位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()] 可以找到 feat_dst

确定 feat_srcfeat_dst 之后,以上3种图类型的计算方法是相同的。

消息传递和聚合

import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
       if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # 除以入度
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE中gcn聚合不需要fc_self
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。

聚合后,更新特征作为输出

 # 激活函数
 if self.activation is not None:
        rst = self.activation(rst)
       # 归一化
   if self.norm is not None:
        rst = self.norm(rst)
        return rst

forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。
常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

3 简单的图分类任务

在本教程中,我们将学习如何使用 DGL 执行图分类,这个例子的任务目标就是对下面显示的八种拓扑类型Grpah进行分类。


这里我们直接使用 DGL 中合成数据集 data.MiniGCDataset。数据集有八种不同类型的图,每个类都有相同数量的图样本

from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, each graph is
# of size [10, 20]
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()
Using backend: pytorch

创建graph的批数据

import dgl
import torch

def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels,dtype=torch.long)

构建Graph分类器

from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs, the in-degree
        # is the same as the out_degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
Epoch 0, loss 2.0010
Epoch 1, loss 1.9744
Epoch 2, loss 1.9551
Epoch 3, loss 1.9444
Epoch 4, loss 1.9318
Epoch 5, loss 1.9170
Epoch 6, loss 1.8928
Epoch 7, loss 1.8573
Epoch 8, loss 1.8212
Epoch 9, loss 1.7715
Epoch 10, loss 1.7152
Epoch 11, loss 1.6570
Epoch 12, loss 1.5885
Epoch 13, loss 1.5308
Epoch 14, loss 1.4719
Epoch 15, loss 1.4158
Epoch 16, loss 1.3515
Epoch 17, loss 1.2963
Epoch 18, loss 1.2417
Epoch 19, loss 1.1978
Epoch 20, loss 1.1698
Epoch 21, loss 1.1086
Epoch 22, loss 1.0780
Epoch 23, loss 1.0459
Epoch 24, loss 1.0192
Epoch 25, loss 1.0017
Epoch 26, loss 1.0297
Epoch 27, loss 0.9784
Epoch 28, loss 0.9486
Epoch 29, loss 0.9327
Epoch 30, loss 0.9133
Epoch 31, loss 0.9265
Epoch 32, loss 0.9177
Epoch 33, loss 0.9303
Epoch 34, loss 0.8666
Epoch 35, loss 0.8639
Epoch 36, loss 0.8474
Epoch 37, loss 0.8858
Epoch 38, loss 0.8393
Epoch 39, loss 0.8306
Epoch 40, loss 0.8204
Epoch 41, loss 0.8057
Epoch 42, loss 0.7998
Epoch 43, loss 0.7909
Epoch 44, loss 0.7840
Epoch 45, loss 0.7807
Epoch 46, loss 0.7882
Epoch 47, loss 0.7701
Epoch 48, loss 0.7612
Epoch 49, loss 0.7563
Epoch 50, loss 0.7430
Epoch 51, loss 0.7354
Epoch 52, loss 0.7357
Epoch 53, loss 0.7326
Epoch 54, loss 0.7249
Epoch 55, loss 0.7181
Epoch 56, loss 0.7146
Epoch 57, loss 0.7306
Epoch 58, loss 0.7143
Epoch 59, loss 0.7018
Epoch 60, loss 0.7130
Epoch 61, loss 0.7003
Epoch 62, loss 0.6977
Epoch 63, loss 0.7120
Epoch 64, loss 0.6979
Epoch 65, loss 0.7370
Epoch 66, loss 0.7223
Epoch 67, loss 0.6980
Epoch 68, loss 0.6891
Epoch 69, loss 0.6715
Epoch 70, loss 0.6736
Epoch 71, loss 0.6709
Epoch 72, loss 0.6583
Epoch 73, loss 0.6717
Epoch 74, loss 0.6683
Epoch 75, loss 0.6656
Epoch 76, loss 0.6477
Epoch 77, loss 0.6414
Epoch 78, loss 0.6442
Epoch 79, loss 0.6398
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))
Accuracy of sampled predictions on the test set: 58.7500%
Accuracy of argmax predictions on the test set: 62.500000%
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章