圖神經網絡11-GCN落地的必讀論文:GraphSAGE

1 GraphSAGE論文簡介

論文:Inductive Representation Learning on Large Graphs 在大圖上的歸納表示學習
鏈接:https://arxiv.org/abs/1706.02216
作者:Hamilton, William L. and Ying, Rex and Leskovec, Jure(斯坦福)
來源:NIPS 2017
代碼:https://github.com/williamleif/graphsage-simple/

此文提出的方法叫GraphSAGE,針對的問題是之前的網絡表示學習的transductive,從而提出了一個inductive的GraphSAGE算法。GraphSAGE同時利用節點特徵信息和結構信息得到Graph Embedding的映射,相比之前的方法,之前都是保存了映射後的結果,而GraphSAGE保存了生成embedding的映射,可擴展性更強,對於節點分類和鏈接預測問題的表現也比較突出

2 GraphSAGE動機

第一點:大多數graph embedding框架是transductive(直推式的), 只能對一個固定的圖生成embedding。這種transductive的方法不能對圖中沒有的新節點生成embedding。

第二點:相對的,GraphSAGE是一個inductive(歸納式)框架,能夠高效地利用節點的屬性信息對新節點生成embedding。

這裏的transductive和inductive用的很精髓,統計機器學習可以分成兩種: transductive learning, inductive learning,這裏我們可以分別成爲直推學習和歸納學習。

  • transductive learning: To specific (test) cases, 指的是測試集是特定的(固定的樣本
  • inductive learning: 測試集不是特定的。一般我們的目的是做 inductive learning。

爲了搞懂 transductive learning和inductive learning,我們可以看下西方國家法律體系和大陸法系的區別:
(1)Transductive Learning:從彼個例到此個例,有點象英美法系,實際案例直接結合過往的判例進行判決。關注具體實踐。
(2)Inductive Learning:從多個個例歸納出普遍性,再演繹到個例,有點象大陸法系,先對過往的判例歸納總結出法律條文,再應用到實際案例進行判決。從有限的實際樣本中,企圖歸納出普遍真理,傾向形而上,往往會不由自主地成爲教條。

GNN中經典的DeepWalk, GCN方法都是transductive learning,大多數節點嵌入模型都基於頻譜分解/矩陣分解方法。而這些方法問題是矩陣分解方法本質上是transductive 的!簡而言之,transductive 方法在處理以前從未見過的數據時效果不佳。這些方法需要整個圖形結構的節點在訓練時都出現,以生成節點嵌入。如果之後有新的節點添加到Gparh,則需要重新訓練模型。而GraphSAGE方法學到的node embedding,是根據node的鄰居關係的變化而變化的,也就是說,即使是舊的node,如果建立了一些新的link,那麼其對應的embedding也會變化,而且也很方便地學到。

3 相關工作

GraphSAGE算法在概念上與以前的節點embedding方法、一般的圖形學習監督方法以及最近將卷積神經網絡應用於圖形結構化數據的進展有關。

3.1 Factorization-based embedding approaches(節點embedding)

一些node embedding方法使用隨機遊走的統計方法和基於矩陣分解學習目標學習低維的embeddings

  • Grarep: Learning graph representations with global structural information. In KDD, 2015
  • node2vec: Scalable feature learning for networks. In KDD, 2016
  • Deepwalk: Online learning of social representations. In KDD, 2014
  • Line: Large-scale information network embedding. In WWW, 2015
  • Structural deep network embedding. In KDD, 2016
    這些embedding算法直接訓練單個節點的節點embedding,本質上是transductive,而且需要大量的額外訓練(如隨機梯度下降)使他們能預測新的頂點。

此外,Yang et al.的Planetoid-I算法,是一個inductive的基於embedding的半監督學習算法。然而,Planetoid-I在推斷的時候不使用任何圖結構信息,而在訓練的時候將圖結構作爲一種正則化的形式。

不像前面的這些方法,本文利用特徵信息來訓練可以對未見過的頂點生成embedding的模型。

3.2 Supervised learning over graphs

Graph kernel
除了節點嵌入方法,還有大量關於圖結構數據的監督學習的文獻。這包括各種各樣的基於內核的方法,其中圖的特徵向量來自不同的圖內核(參見Weisfeiler-lehman graph kernels和其中的引用)。

一些神經網絡方法用於圖結構上的監督學習,本文的方法在概念上受到了這些算法的啓發

  • Discriminative embeddings of latent variable models for structured data. In - ICML, 2016
  • A new model for learning in graph domains
  • Gated graph sequence neural networks. In ICLR, 2015
  • The graph neural network model
    然而,這些以前的方法是嘗試對整個圖(或子圖)進行分類的,但是本文的工作的重點是爲單個節點生成有用的表示。

3.3 Graph convolutional networks

近年來,提出了幾種用於圖上學習的卷積神經網絡結構

  • Spectral networks and locally connected networks on graphs. In ICLR, 2014
    Convolutional neural networks on graphs with fast localized spectral filtering. In NIPS, 2016
  • Convolutional networks on graphs for learning molecular fingerprints. In NIPS,2015
  • Semi-supervised classification with graph convolutional networks. In ICLR, 2016
  • Learning convolutional neural networks for graphs. In ICML, 2016
    這些方法中的大多數不能擴展到大型圖,或者設計用於全圖分類(或者兩者都是)。

原文鏈接:https://blog.csdn.net/yyl424525/article/details/100532849

4 GraphSAGE 核心思想

GraphSAGE的核心:GraphSAGE不是試圖學習一個圖上所有node的embedding,而是學習一個爲每個node產生embedding的映射。


在上圖中,如果對《史酷比狗》劇情熟悉的話,我們很清楚第知道Fred,Velma,Daphne和Shaggy這些角色,我們可以回想下哪個角色與上面四個成員有關係呢?我們腦子裏第一印象應該是史酷比,所以說我們可以認爲史酷比的鄰居節點近似地表示了目標節點。

論文中提出的方法稱爲GraphSAGE, SAGE指的是 Sample and Aggregate,不是對每個頂點都訓練一個單獨的embeddding向量,而是訓練了一組aggregator functions,這些函數學習如何從一個頂點的局部鄰居聚合特徵信息。每個聚合函數從一個頂點的不同的hops或者說不同的搜索深度聚合信息。測試或是推斷的時候,使用訓練好的系統,通過學習到的聚合函數來對完全未見過的頂點生成embedding。


上面是爲紅色的目標節點生成embedding的過程。k表示距離目標節點的搜索深度,k=1就是目標節點的相鄰節點,k=2表示目標節點相鄰節點的相鄰節點。
對於上圖中的例子:

  • 第一步是採樣,k=1採樣了3個節點,對k=2採用了5個節點;
  • 第二步是聚合鄰居節點的信息,獲得目標節點的embedding;
  • 第三步是使用聚合得到的信息,也就是目標節點的embedding,來預測圖中想預測的信息;

5 GraphSAGE模型細節

GraphSAGE的目標是基於參數h的相鄰節點的某種組合來學習每個節點的表示形式。



稍微回顧下,Graph中的每個節點都可以擁有自己的特徵向量,該特徵向量由X節點特徵得到。現在讓我們假設每個節點的所有特徵向量都具有相同的大小。一層GraphSAGE可以運行k次迭代-因此,每k次迭代,每個節點都有一個節點表示h。


其中:
X_{v}代表某個節點v的輸入特徵
h_{v}^{0}代表節點v的初始化向量表示
h_{v}^{k}代表節點vk次迭代之後的向量表示
z_{v}代表某個節點v經過GraphSAGE模型之後的最終輸出向量

因爲每個節點都可以由它們的鄰居近似表示,所以節點A的嵌入可以用其鄰近節點嵌入向量的某種組合來表示。 通過一輪GraphSAGE算法,我們將獲得節點A的新表示形式。原始圖中的所有節點都遵循相同的過程。

GraphSAGE算法遵循兩步過程。由於它是迭代的,因此存在一個初始化步驟,該步驟將所有初始節點嵌入向量設置爲其特徵向量。(k從1…K開始迭代)


步驟1 Aggregate


aggregator 的作用是把一個向量的集合轉換成向量,也就是聚合。和其他機器學習任務中的數據(如圖像,文本等)不同,圖中的節點是沒有順序的(node’s neighbors have no natural ordering),aggregator function操作的是一個無序的向量集合\{h_{u}^{k-1},\forall u\in{N(v)}\}。其中N(v)代表了節點v的鄰居節點集合。
這篇文章嘗試了多種aggregator function:

  • Mean aggregator:顯然對向量集合,對應元素取均值是最直接的想法。
  • LSTM aggregator:和mean aggregator相比,LSTM有更大的表達能力。但是LSTM不符合symmetric的性質,輸入是有順序的。所以把相鄰節點的向量集合隨機打亂順序,然後作爲LSTM的輸入。
  • Pooling aggregator:嘗試了pooling做aggregator, 所有相鄰節點的向量共享權重,先經過一個非線性全連接層,然後做max-pooling.

爲說明起見,請觀察下圖。與其將節點B的表示初始化爲其特徵向量,我們實際上可以運行此聚合更新功能來基於節點B的鄰居獲取節點B的表示形式。我們可以對k = 1層中的節點C和D執行相同的操作。在k = 0層中,我們將初始化嵌入其初始特徵向量的鄰居節點。



在上面的示例中,我們簡單地設置k = 2並使用節點A的鄰居和鄰居鄰居獲得最終的目標節點表示形式。您可能會嘗試使用多個鄰域,即更大的k值。但是,太多的鄰域可能會稀釋節點v的節點表示形式,但是太少的鄰域(少於2個)可能類似於不使用GNN而是隻使用MLP而已–值得深思

步驟2 Update

在基於節點v的鄰居獲得聚合表示後,請使用其先前表示和聚合表示的組合來更新當前節點v。該f_update功能爲任何可微函數,可以再次,是一樣簡單的平均函數,或複雜如神經網絡。

根據節點v的鄰域聚合表示和節點v的先前表示,爲節點v創建更新的表示:

因此,現在再理解原始論文中的以下算法片段時,我們應該沒有問題了:


關於本文實現的一些注意事項:
第4行:作者嘗試了多種聚合器功能,包括使用最大池,均值聚合甚至LSTM聚合。LSTM聚合方法要求每個k迭代都要對節點進行混洗,以便在計算聚合時暫時不偏向任何一個節點。
第4行:在本文中,我們概括爲f_aggregate的內容實際上表示爲AGGREGATE_k。
第5行:本文中的f_update函數是一個串聯操作。因此,級聯後,輸出的形狀爲尺寸(2F,1)。級聯的輸出通過權重矩陣W ^ k的矩陣乘法進行變換。該權重矩陣旨在將輸出的維數減小爲(F,1)。最後,級聯和變換後的節點嵌入向量經歷非線性。
第5行:每個k迭代都有一個單獨的權重矩陣。這具有學習權重的解釋,該權重具有多個鄰域對目標節點的重要性的感覺。
第7行:通過除以矢量範數來標準化節點嵌入,以防止梯度爆炸。

6 模型訓練-無監督損失函數

那麼,如何實際訓練GraphSAGE GNN?
作者訓練了無監督和有監督的GraphSAGE模型。有監督的設置遵循針對節點分類任務的常規交叉熵樣式預測。但是,無監督的情況會嘗試通過執行以下損失函數來保留圖結構:


損失函數的藍色部分試圖強制說明,如果節點u和v在實際圖中接近,則它們的節點嵌入在語義上應該相似。在理想情況下,我們期望z_{u}z_{v}的內積很大。如此大的數值輸入到sigmoid輸出會接近1log(1)= 0

損失函數的粉紅色部分試圖強制執行相反的操作!也就是說,如果節點u和v在實際圖形中實際上相距較遠,則我們期望它們的節點嵌入是不同的/相反的。在理想情況下,我們期望z_{u}z_{v}的內積爲較大的負數。可以解釋爲,嵌入z_{u}z_{v}差別很大,以至於它們之間的距離大於90度。兩個大負數的乘積變成一個大正數。如此大的數值輸入到sigmoid輸出會接近1log(1)=0。由於可能有更多的節點u遠離我們的目標節點v在圖中,我們從遠離節點v的節點分佈中僅採樣了幾個負節點u:P_{n}(v)。這樣可以確保訓練時的損失功能達到平衡。

另外添加epsilon可以確保我們永遠不會取log(0)

7 實驗結果

實驗給了三個圖,效果,效率,採樣數量對效果和性能的影響。

三個數據集上的實驗結果表明,一般是LSTM或pooling效果比較好。有監督都比無監督好。


8 代碼

作者在論文裏用的tensorflow,但是也開源了一個簡單, 容易擴展的pytorch版本。
pytorch版本中用的兩個數據集都比較小,不是論文裏用的數據集。這兩個數據集在Kipf 16年經典的GCN論文用到了。節點數量分別約是2700,20000。

cora是一個機器學習論文引用數據集,提供了2708篇論文的引用關係,每篇論文的label是論文所屬的領域。label一共七種,包括遺傳算法,神經網絡,強化學習等7個領域。特徵是已經經過stemming和stopwords處理過的詞表,每列表示一個詞是否出現。
aggregators核心代碼:

import torch
import torch.nn as nn
from torch.autograd import Variable

import random

"""
Set of modules for aggregating embeddings of neighbors.
"""

class MeanAggregator(nn.Module):
    """
    Aggregates a node's embeddings using mean of neighbors' embeddings
    """
    def __init__(self, features, cuda=False, gcn=False): 
        """
        Initializes the aggregator for a specific graph.
        features -- function mapping LongTensor of node ids to FloatTensor of feature values.
        cuda -- whether to use GPU
        gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
        """

        super(MeanAggregator, self).__init__()

        self.features = features
        self.cuda = cuda
        self.gcn = gcn
        
    def forward(self, nodes, to_neighs, num_sample=10):
        """
        nodes --- list of nodes in a batch
        to_neighs --- list of sets, each set is the set of neighbors for node in batch
        num_sample --- number of neighbors to sample. No sampling if None.
        """
        # Local pointers to functions (speed hack)
        _set = set
        if not num_sample is None:
            _sample = random.sample
            samp_neighs = [_set(_sample(to_neigh, 
                            num_sample,
                            )) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
        else:
            samp_neighs = to_neighs

        if self.gcn:
            samp_neighs = [samp_neigh + set([nodes[i]]) for i, samp_neigh in enumerate(samp_neighs)]
        unique_nodes_list = list(set.union(*samp_neighs))
        unique_nodes = {n:i for i,n in enumerate(unique_nodes_list)}
        mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
        column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]   
        row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
        mask[row_indices, column_indices] = 1
        if self.cuda:
            mask = mask.cuda()
        num_neigh = mask.sum(1, keepdim=True)
        mask = mask.div(num_neigh)
        if self.cuda:
            embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
        else:
            embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
        to_feats = mask.mm(embed_matrix)
        return to_feats

Encoder節點編碼

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

class Encoder(nn.Module):
    """
    Encodes a node's using 'convolutional' GraphSage approach
    """
    def __init__(self, features, feature_dim, 
            embed_dim, adj_lists, aggregator,
            num_sample=10,
            base_model=None, gcn=False, cuda=False, 
            feature_transform=False): 
        super(Encoder, self).__init__()

        self.features = features
        self.feat_dim = feature_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        self.num_sample = num_sample
        if base_model != None:
            self.base_model = base_model

        self.gcn = gcn
        self.embed_dim = embed_dim
        self.cuda = cuda
        self.aggregator.cuda = cuda
        self.weight = nn.Parameter(
                torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim))
        init.xavier_uniform(self.weight)

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.
        nodes     -- list of nodes
        """
        neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes], 
                self.num_sample)
        if not self.gcn:
            if self.cuda:
                self_feats = self.features(torch.LongTensor(nodes).cuda())
            else:
                self_feats = self.features(torch.LongTensor(nodes))
            combined = torch.cat([self_feats, neigh_feats], dim=1)
        else:
            combined = neigh_feats
        combined = F.relu(self.weight.mm(combined.t()))
        return combined

GraphSAGE訓練模型

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable

import numpy as np
import time
import random
from sklearn.metrics import f1_score
from collections import defaultdict

from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator

"""
Simple supervised GraphSAGE model as well as examples running the model
on the Cora and Pubmed datasets.
"""

class SupervisedGraphSage(nn.Module):

    def __init__(self, num_classes, enc):
        super(SupervisedGraphSage, self).__init__()
        self.enc = enc
        self.xent = nn.CrossEntropyLoss()

        self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        init.xavier_uniform(self.weight)

    def forward(self, nodes):
        embeds = self.enc(nodes)
        scores = self.weight.mm(embeds)
        return scores.t()

    def loss(self, nodes, labels):
        scores = self.forward(nodes)
        return self.xent(scores, labels.squeeze())

9 參考資料

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章