在上一篇文章中介紹了GCN
【Graph Neural Network】GCN: 算法原理,實現和應用
GCN是一種在圖中結合拓撲結構和頂點屬性信息學習頂點的embedding表示的方法。然而GCN要求在一個確定的圖中去學習頂點的embedding,無法直接泛化到在訓練過程沒有出現過的頂點,即屬於一種直推式(transductive)的學習。
本文介紹的GraphSAGE則是一種能夠利用頂點的屬性信息高效產生未知頂點embedding的一種歸納式(inductive)學習的框架。
其核心思想是通過學習一個對鄰居頂點進行聚合表示的函數來產生目標頂點的embedding向量。
GraphSAGE算法原理在這裏插入圖片描述
GraphSAGE 是Graph SAmple and aggreGatE的縮寫,其運行流程如上圖所示,可以分爲三個步驟
對圖中每個頂點鄰居頂點進行採樣
根據聚合函數從聚合鄰居頂點蘊含的信息
得到圖中各頂點的向量表示供下游任務
採樣鄰居頂點
出於對計算效率的考慮,對每個頂點採樣一定數量的鄰居頂點作爲待聚合信息的頂點。設採樣數量爲k,若頂點鄰居數少於k,則採用有放回的抽樣方法,直到採樣出k個頂點。若頂點鄰居數大於k,則採用無放回的抽樣。
當然,若不考慮計算效率,我們完全可以對每個頂點利用其所有的鄰居頂點進行信息聚合,這樣是信息無損的。
生成向量的僞代碼
在這裏插入圖片描述
這裏K是網絡的層數,也代表着每個頂點能夠聚合的鄰接點的跳數,如K=2的時候每個頂點可以最多根據其2跳鄰接點的信息學習其自身的embedding表示。
在每一層的循環k中,對每個頂點v,首先使用v的鄰接點的k-1層的embedding表示來產生其鄰居頂點的第k層聚合表示hkN(v)
hN(v)k,之後將hkN(v)hN(v)k和頂點v的第k-1層表示進行拼接,經過一個非線性變換產生頂點v的第k層embedding表示hkv
hvk。
聚合函數的選取
由於在圖中頂點的鄰居是天然無序的,所以我們希望構造出的聚合函數是對稱的(即改變輸入的順序,函數的輸出結果不變),同時具有較高的表達能力。
MEAN aggregator
hkv←σ(W⋅MEAN({hk−1v}∪{hk−1u,∀u∈N(v)})
hvk←σ(W⋅MEAN({hvk−1}∪{huk−1,∀u∈N(v)})
上式對應於僞代碼中的第4-5行,直接產生頂點的向量表示,而不是鄰居頂點的向量表示。
mean aggregator將目標頂點和鄰居頂點的第k-1層向量拼接起來,然後對向量的每個維度進行求均值的操作,將得到的結果做一次非線性變換產生目標頂點的第k層表示向量。
Pooling aggregator
AGGREGATEpoolk=max({σ(Wpoolhkui+b),∀ui∈N(v)})
AGGREGATEkpool=max({σ(Wpoolhuik+b),∀ui∈N(v)})
Pooling aggregator 先對目標頂點的鄰接點表示向量進行一次非線性變換,之後進行一次pooling操作(maxpooling or meanpooling),將得到結果與目標頂點的表示向量拼接,最後再經過一次非線性變換得到目標頂點的第k層表示向量。
LSTM aggregator
LSTM相比簡單的求平均操作具有更強的表達能力,然而由於LSTM函數不是關於輸入對稱的,所以在使用時需要對頂點的鄰居進行一次亂序操作。
參數的學習
在定義好聚合函數之後,接下來就是對函數中的參數進行學習。文章分別介紹了無監督學習和監督學習兩種方式。
無監督學習形式
基於圖的損失函數希望臨近的頂點具有相似的向量表示,同時讓分離的頂點的表示儘可能區分。
目標函數如下
在這裏插入圖片描述
其中v是通過固定長度的隨機遊走出現在u附近的頂點,pn
pn是負採樣的概率分佈,Q
Q是負樣本的數量。與DeepWalk不同的是,這裏的頂點表示向量是通過聚合頂點的鄰接點特徵產生的,而不是簡單的進行一個embedding lookup操作產生。
監督學習形式
監督學習形式根據任務的不同直接設置目標函數即可,如最常用的節點分類任務使用交叉熵損失函數。
GraphSAGE的實現
這裏以MEAN aggregator簡單講下聚合函數的實現
MEAN aggregator
features, node, neighbours = inputs
node_feat = tf.nn.embedding_lookup(features, node)
neigh_feat = tf.nn.embedding_lookup(features, neighbours)
concat_feat = tf.concat([neigh_feat, node_feat], axis=1)
concat_mean = tf.reduce_mean(concat_feat,axis=1,keep_dims=False)
output = tf.matmul(concat_mean, self.neigh_weights)
if self.use_bias:
output += self.bias
if self.activation:
output = self.activation(output)
對於第k層的aggregator,features爲第k−1
k−1層所有頂點的向量表示矩陣,node和neighbours分別爲第k層採樣得到的頂點集合及其對應的鄰接點集合。
首先通過embedding_lookup操作獲取得到頂點和鄰接點的第k−1
k−1層的向量表示。然後通過concat將他們拼接成一個(batch_size,1+neighbour_size,embeding_size)的張量,使用reduce_mean對每個維度求均值得到一個(batch_size,embedding_size)的張量。
最後經過一次非線性變換得到output,即所有頂點的第k層的表示向量
GraphSAGE
下面是完整的GraphSAGE方法的代碼
def GraphSAGE(feature_dim, neighbor_num, n_hidden, n_classes, use_bias=True, activation=tf.nn.relu,
aggregator_type='mean', dropout_rate=0.0, l2_reg=0):
features = Input(shape=(feature_dim,))
node_input = Input(shape=(1,), dtype=tf.int32)
neighbor_input = [Input(shape=(l,),dtype=tf.int32) for l in neighbor_num]
if aggregator_type == 'mean':
aggregator = MeanAggregator
else:
aggregator = PoolingAggregator
h = features
for i in range(0, len(neighbor_num)):
if i > 0:
feature_dim = n_hidden
if i == len(neighbor_num) - 1:
activation = tf.nn.softmax
n_hidden = n_classes
h = aggregator(units=n_hidden, input_dim=feature_dim, activation=activation, l2_reg=l2_reg, use_bias=use_bias,
dropout_rate=dropout_rate, neigh_max=neighbor_num[i])(
[h, node_input,neighbor_input[i]])#
output = h
input_list = [features, node_input] + neighbor_input
model = Model(input_list, outputs=output)
return model
其中feature_dim表示頂點屬性特徵向量的維度,neighbor_num是一個list表示每一層抽樣的鄰居頂點的數量,n_hidden爲聚合函數內部非線性變換時的參數矩陣的維度,n_classes表示預測的類別的數量,aggregator_type爲使用的聚合函數的類別。
GraphSAGE應用
本例中的訓練,評測和可視化的完整代碼在下面的git倉庫中
https://github.com/shenweichen/GraphNeuralNetwork
這裏我們使用引文網絡數據集Cora進行測試,Cora數據集包含2708個頂點, 5429條邊,每個頂點包含1433個特徵,共有7個類別。
按照論文的設置,從每個類別中選取20個共140個頂點作爲訓練,500個頂點作爲驗證集合,1000個頂點作爲測試集。
採樣時第1層採樣10個鄰居,第2層採樣25個鄰居。
節點分類任務結果
通過多次運行準確率在0.80-0.82之間。
節點向量可視化
從對得到的節點向量的可視化結果來看,GCN得到的向量相比於DeepWalk產出的向量確實更加能夠將同類的頂點聚集在一起,不同類的頂點區分開來。
參考資料
Hamilton W, Ying Z, Leskovec J. Inductive representation learning on large graphs[C]//Advances in Neural Information Processing Systems. 2017: 1024-1034.(https://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf)