GraphSAGE
GCN是一種在圖中結合拓撲結構和頂點屬性信息學習頂點的embedding表示的方法。然而GCN要求在一個確定的圖中去學習頂點的embedding,無法直接泛化到在訓練過程沒有出現過的頂點,即屬於一種直推式(transductive)的學習。
本文介紹的GraphSAGE則是一種能夠利用頂點的屬性信息高效產生未知頂點embedding的一種歸納式(inductive)學習的框架。
其核心思想是通過學習一個對鄰居頂點進行聚合表示的函數來產生目標頂點的embedding向量/
GraphSAGE算法原理
GraphSAGE 是Graph SAmple and aggreGatE的縮寫,其運行流程如上圖所示,可以分爲三個步驟
-
對圖中每個頂點鄰居頂點進行採樣
-
根據聚合函數聚合鄰居頂點蘊含的信息
-
得到圖中各頂點的向量表示供下游任務使用
採樣鄰居頂點
出於對計算效率的考慮,對每個頂點採樣一定數量的鄰居頂點作爲待聚合信息的頂點。設採樣數量爲k,若頂點鄰居數少於k,則採用有放回的抽樣方法,直到採樣出k個頂點。若頂點鄰居數大於k,則採用無放回的抽樣。
當然,若不考慮計算效率,我們完全可以對每個頂點利用其所有的鄰居頂點進行信息聚合,這樣是信息無損的。
生成向量的僞代碼
這裏K是網絡的層數,也代表着每個頂點能夠聚合的鄰接點的跳數,如K=2的時候每個頂點可以最多根據其2跳鄰接點的信息學習其自身的embedding表示。
聚合函數的選取
由於在圖中頂點的鄰居是天然無序的,所以我們希望構造出的聚合函數是對稱的(即改變輸入的順序,函數的輸出結果不變),同時具有較高的表達能力。
MEAN aggregator
上式對應於僞代碼中的第4-5行,直接產生頂點的向量表示,而不是鄰居頂點的向量表示。 mean aggregator將目標頂點和鄰居頂點的第k-1層向量拼接起來,然後對向量的每個維度進行求均值的操作,將得到的結果做一次非線性變換產生目標頂點的第k層表示向量。
Pooling aggregator
Pooling aggregator 先對目標頂點的鄰接點表示向量進行一次非線性變換,之後進行一次pooling操作(maxpooling or meanpooling),將得到結果與目標頂點的表示向量拼接,最後再經過一次非線性變換得到目標頂點的第k層表示向量。
LSTM aggregator
LSTM相比簡單的求平均操作具有更強的表達能力,然而由於LSTM函數不是關於輸入對稱的,所以在使用時需要對頂點的鄰居進行一次亂序操作。
參數的學習
1.無監督學習形式
基於圖的損失函數希望臨近的頂點具有相似的向量表示,同時讓分離的頂點的表示儘可能區分。
2.監督學習形式
監督學習形式根據任務的不同直接設置目標函數即可,如最常用的節點分類任務使用交叉熵損失函數。