【Code】GraphSAGE 源碼解析

1.GraphSAGE

本文代碼源於 DGL 的 Example 的,感興趣可以去 github 上面查看。

閱讀代碼的本意是加深對論文的理解,其次是看下大佬們實現算法的一些方式方法。當然,在閱讀 GraphSAGE 代碼時我也發現了之前忽視的 GraphSAGE 的細節問題和一些理解錯誤。比如說:之前忽視了 GraphSAGE 的四種聚合方式的具體實現,對 Alogrithm 2 的算法理解也有問題,再回頭看那篇 GraphSAGE 的推文時,實在慘不忍睹= =。

進入正題,簡單回顧一下 GraphSAGE。

核心算法:

2.SAGEConv

dgl 已經實現了 SAGEConv 層,所以我們可以直接導入。

有了 SAGEConv 層後,GraphSAGE 實現起來就比較簡單。

和基於 GraphConv 實現 GCN 的唯一區別在於把 GraphConv 改成了 SAGEConv:

class GraphSAGE(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type,
                                    feat_drop=dropout, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type,
                                        feat_drop=dropout, activation=activation))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type,
                                    feat_drop=dropout, activation=None)) # activation None
        
    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(self.g, h)
        return h

來看一下 SAGEConv 是如何實現的

SAGEConv 接收七個參數:

  • in_feat:輸入的特徵大小,可以是一個整型數,也可以是兩個整型數。如果用在單向二部圖上,則可以用整型數來分別表示源節點和目的節點的特徵大小,如果只給一個的話,則默認源節點和目的節點的特徵大小一致。需要注意的是,如果參數 aggregator_type 爲 gcn 的話,則源節點和目的節點的特徵大小必須一致;
  • out_feats:輸出特徵大小;
  • aggregator_type:聚合類型,目前支持 mean、gcn、pool、lstm,比論文多一個 gcn 聚合,gcn 聚合可以理解爲周圍所有的鄰居結合和當前節點的均值;
  • feat_drop=0.:特徵 drop 的概率,默認爲 0;
  • bias=True:輸出層的 bias,默認爲 True;
  • norm=None:歸一化,可以選擇一個歸一化的方式,默認爲 None
  • activation=None:激活函數,可以選擇一個激活函數去更新節點特徵,默認爲 None。
class SAGEConv(nn.Module):
    

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        # expand_as_pair 函數可以返回一個二維元組。
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """初始化參數
        這裏的 gain 可以從 calculate_gain 中獲取針對非線形激活函數的建議的值
        用於初始化參數
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat):
        """ SAGE 層的前向傳播
        接收 DGLGraph 和 Tensor 格式的節點特徵
        """
        # local_var 會返回一個作用在內部函數中使用的 Graph 對象
        # 外部數據的變化不會影響到這個 Graph
        # 可以理解爲保護數據不被意外修改
        graph = graph.local_var()

        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)

        h_self = feat_dst

        # 根據不同的聚合類型選擇不同的聚合方式
        # 值得注意的是,論文在 3.3 節只給出了三種聚合方式
        # 而這裏卻多出來一個 gcn 聚合
        # 具體原因將在後文給出
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            # check_eq_shape 用於檢查源節點和目的節點的特徵大小是否一致
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same as above if homogeneous
            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
            # divide in_degrees
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'lstm':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst

reset_parameters 函數那裏有一個 gain,初始化參數服從 Xavier 均勻分佈:

WU[gain6nj+nj+1,gaingain6nj+nj+1] W \sim U[- \frac{\text{gain} \sqrt{6}}{\sqrt{n_j+n_{j+1}}}, \text{gain} \frac{\text{gain} \sqrt{6}}{\sqrt{n_j+n_{j+1}}}] \\
仔細閱讀論文時會發現,在實驗部分作者給出了四種方式的聚合方法:

配合着論文,我們來閱讀下代碼

  1. MEAN 聚合器:首先對鄰居節點進行均值聚合,然後當前節點特徵與鄰居節點特徵該分別送入全連接網絡後相加得到結果,對應僞代碼如下:

hN(v)kMEANk({huk1,uN(v)})hvkσ(WkCONCAT({hvk1,hN(v)k}) h_{N(v)}^k \leftarrow \text{MEAN}_k(\{ \mathbf{h}_u^{k-1}, \forall u \in N(v )\}) \\ h_v^k \leftarrow \sigma(\mathbf{W^k} \cdot \text{CONCAT}(\{\mathbf{h}_v^{k-1}, h_{N(v)}^k\} ) \\

對應代碼如下:

h_self = feat_dst
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
# 公式裏寫的是 concat,這裏是 element-wise 的和。
# 稍微有些出入,不過問題不大。
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
  1. GCN 聚合:首先對鄰居節點的特徵和自身節點的特徵求均值,得到的聚合特徵送入到全連接網絡中,對應論文公式如下:

hvkσ(WMEAN({hvk1}huk1,uN(v)}) h_v^k \leftarrow \sigma(\mathbf{W} \cdot \text{MEAN}(\{\mathbf{h}_v^{k-1}\} \cup \mathbf{h}_u^{k-1}, \forall u \in N(v )\} )

對應代碼如下:

graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst   
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
# 公式中給出集合並集,這裏做 element-wise 的和,問題也不大。
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
rst = self.fc_neigh(h_neigh)

gcn 與 mean 的關鍵區別在於節點鄰居節點和當前節點取平均的方式:gcn 是直接將當前節點和鄰居節點取平均,而 mean 聚合是 concat 當前節點的特徵和鄰居節點的特徵,所以前者只經過一個全連接層,而後者是分別經過全連接層

這裏利用下斯坦福大學的同學實現的 GCN 聚合器的解釋,如果不明白的話,可以去其 github 倉庫查看源碼:

class MeanAggregator(Layer):
    """
    Aggregates via mean followed by matmul and non-linearity.
    """

class GCNAggregator(Layer):
    """
    Aggregates via mean followed by matmul and non-linearity.
    Same matmul parameters are used self vector and neighbor vectors.
    """
  1. POOL 聚合器:池化方法中,每一個節點的向量都會對應一個全連接神經網絡,然後基於 elementwise 取最大池化操作。對應公式如下:

AGGREGATEkpool=max({Wpoolhuik+b,uiN(v)}) \text{AGGREGATE}_k^{pool} = \text{max}( \{\mathbf{W}_{pool} \mathbf{h}_{u_i}^k + \mathbf{b}, \forall u_i \in N(v) \} ) \\

對應代碼如下:

graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
  1. LSTM 聚合器:其表達能力比 mean 聚合器要強,但是 LSTM 是非對稱的,即其考慮節點的順序性,論文作者通過將節點進行隨機排列來調整 LSTM 對無序集的支持。
def _lstm_reducer(self, nodes):
  """LSTM reducer
  """
  m = nodes.mailbox['m'] # (B, L, D)
  batch_size = m.shape[0]
  h = (m.new_zeros((1, batch_size, self._in_src_feats)),
       m.new_zeros((1, batch_size, self._in_src_feats)))
  _, (rst, _) = self.lstm(m, h)
  return {'neigh': rst.squeeze(0)}

graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.dstdata['neigh']

以上便是利用 SAGEConv 實現 GraphSAGE 的方法,剩餘訓練的內容與前文介紹 GCN 一致,不再進行介紹。

3.Neighbor sampler

這裏再介紹一種基於節點鄰居採樣並利用 minibatch 的方法進行前向傳播的實現。

這種方法適用於大圖,並且能夠並行計算。

import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.nn import SAGEConv
import time
from dgl.data import RedditDataset
import tqdm

首先是鄰居採樣(NeighborSampler),這個最好配合着 PinSAGE 的實現來看:

我們關注下上半部分,首先對節點 A 的一階鄰居進行採樣,然後再進行二階鄰居採樣,節點 A 的二階鄰居可能會包括節點 A 及其一階鄰居。

Neighbor Sampler 函數的實現目的與之類似,首先獲取最右邊的種子節點,然後依次進行一階採樣和二階採樣。採樣的方向是從左到右,而特徵聚合方向是從從右到左。

class NeighborSampler(object):
    def __init__(self, g, fanouts):
        """
        g 爲 DGLGraph;
        fanouts 爲採樣節點的數量,實驗使用 10,25,指一階鄰居採樣 10 個,二階鄰居採樣 25 個。
        """
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts: 
            # sample_neighbors 可以對每一個種子的節點進行鄰居採樣並返回相應的子圖
            # replace=True 表示用採樣後的鄰居節點代替所有鄰居節點
            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True)
            # 將圖轉變爲可以用於消息傳遞的二部圖(源節點和目的節點)
            # 其中源節點的 id 也可能包含目的節點的 id(原因上面說了)
            # 轉變爲二部圖主要是爲了方便進行消息傳遞
            block = dgl.to_block(frontier, seeds)
            # 獲取新圖的源節點作爲種子節點,爲下一層作準備
            # 之所以是從 src 中獲取種子節點,是因爲採樣操作相對於聚合操作來說是一個逆向操作
            seeds = block.srcdata[dgl.NID]
            # 把這一層放在最前面。
            # PS:如果數據量大的話,插入操作是不是不太友好。
            blocks.insert(0, block)
        return blocks

Algorithm 2 僞代碼如下所示,NeighborSampler 對應 Algorithm 2 算法的 1-7 步:

# GraphSAGE 的代碼實現
class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        # block 是我們採樣獲得的二部圖,這裏用於消息傳播
        # x 爲節點特徵
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h_dst = h[:block.number_of_dst_nodes()]
            h = layer(block, (h, h_dst))
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        # inference 用於評估測試,針對的是完全圖
        # 目前會出現重複計算的問題,優化方案還在 to do list 上
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.number_of_nodes(), 
                         self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[start:end] = h.cpu()
            x = y
        return y
def compute_acc(pred, labels):
    """
    計算準確率
    """
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
    """
    評估模型,調用 model 的 inference 函數
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_mask], labels[val_mask])

def load_subtensor(g, labels, seeds, input_nodes, device):
    """
    將一組節點的特徵和標籤複製到 GPU 上。
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels
# 參數設置
gpu = -1
num_epochs = 20
num_hidden = 16
num_layers = 2
fan_out = '10,25'
batch_size = 1000
log_every = 20  # 記錄日誌的頻率
eval_every = 5
lr = 0.003
dropout = 0.5
num_workers = 0  # 用於採樣進程的數量

if gpu >= 0:
    device = th.device('cuda:%d' % gpu)
else:
    device = th.device('cpu')

# load reddit data
# NumNodes: 232965
# NumEdges: 114848857
# NumFeats: 602
# NumClasses: 41
# NumTrainingSamples: 153431
# NumValidationSamples: 23831
# NumTestSamples: 55703
data = RedditDataset(self_loop=True)
train_mask = data.train_mask
val_mask = data.val_mask
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features

開始訓練:

train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)

# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in fan_out.split(',')])

# Create PyTorch DataLoader for constructing blocks
# collate_fn 參數指定了 sampler,可以對 batch 中的節點進行採樣
dataloader = DataLoader(
    dataset=train_nid.numpy(),
    batch_size=batch_size,
    collate_fn=sampler.sample_blocks,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers)

# Define model and optimizer
model = GraphSAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
avg = 0
iter_tput = []
for epoch in range(num_epochs):
    tic = time.time()

    for step, blocks in enumerate(dataloader):
        tic_step = time.time()

        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]

        # Load the input features as well as output labels
        batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, device)

        # Compute loss and prediction
        batch_pred = model(blocks, batch_inputs)
        loss = loss_fcn(batch_pred, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter_tput.append(len(seeds) / (time.time() - tic_step))
        if step % log_every == 0:
            acc = compute_acc(batch_pred, batch_labels)
            gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
            print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))

    toc = time.time()
    print('Epoch Time(s): {:.4f}'.format(toc - tic))
    if epoch >= 5:
        avg += toc - tic
    if epoch % eval_every == 0 and epoch != 0:
        eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, batch_size, device)
        print('Eval Acc {:.4f}'.format(eval_acc))

print('Avg epoch time: {}'.format(avg / (epoch - 4)))

4.Reference

  1. Github:dmlc/dgl
  2. williamleif/GraphSAGE

關注公衆號跟蹤最新內容:阿澤的學習筆記

阿澤的學習筆記

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章