【GNN】MPNN:消息傳遞神經網絡

今天學習的是谷歌大腦的同學 2017 年的工作《Neural Message Passing for Quantum Chemistry》,也就是我們經常提到的消息傳遞網絡(Message Passing Neural Network,MPNN),目前引用數超過 900 次。

嚴格來說,MPNN 不是一個模型,而是一個框架。作者在這篇論文中主要將現有模型抽象其共性並提出成 MPNN 框架,同時利用 MPNN 框架在分子分類預測中取得了一個不錯的成績。

1.Introduction

深度學習被廣泛應用於圖像、音頻、NLP 等領域,但在化學任務(分子分類等)中仍然使用中機器學習+特徵工程的方式,其主要原因在於目前尚未有工作證明深度學習在這個領域能取得很大的成功。

近年來,隨着量子化學計算和分子動力學模擬等實驗的展開產生了巨大的數據量,大多數經典的技術都無法有效利用目前的大數據集。而原子系統的對稱性表明,能夠應用於網絡圖中的神經網絡也能夠應用於分子模型。所以,找到一個更加強大的模型來解決目前的化學任務可以等價於找到一個適用於網絡的模型。

在這篇論文中,作者的目標是證明:能夠應用於化學預測任務的模型可以直接從分子圖中學習到分子的特徵,並且不受到圖同構的影響。爲此,作者將應用於圖上的監督學習框架稱之爲消息傳遞神經網絡(MPNN),這種框架是從目前比較流行的支持圖數據的神經網絡模型中抽象出來的一些共性,抽象出來的目的在於理解它們之間的關係。

鑑於目前已經有很多類似 MPNN 框架的模型,所以作者呼籲學者們應該將這個方法應用到實際的應用中,並且通過實際的應用來提出模型的改進版本,儘可能的去推廣模型的實際應用。

本文給出的一個例子是利用 MPNN 框架代替計算代價昂貴的 DFT 來預測有機分子的量子特性:

2.MPNN

本節內容分爲兩塊,一塊是看下作者如何從現有模型中抽象出 MPNN 框架,另一塊是看下作者如何利用 MPNN 框架去解決實際問題。

2.1 MPNN framework

我們先來介紹下 MPNN 這一通用框架,並通過八篇文獻來舉例驗證 MPNN 框架的通配性。

簡單起見,我們考慮無向圖 G,節點 v 的特徵爲 xvx_v,邊的特徵爲 evwe_{vw}。前向傳遞有兩個階段:一個是消息傳遞階段(Message Passing),另一個是讀出階段(Readout)。考慮消息傳遞階段,消息函數定義爲 MtM_t,頂點更新函數定義爲 UtU_t,t 爲運行的時間步。在消息傳遞過程中,隱藏層節點 v 的狀態 hvth_v^t 可以被基於 mvt+1m_v^{t+1} 進行更新:
mvt+1=wN(v)Mt(hvt,hwt,evw)hvt+1=Ut(hvt,mvt+1) \begin{aligned} m_v^{t+1} &= \sum_{w\in N(v)}M_t(h_v^t, h_w^t,e_{vw}) \\ h_v^{t+1} &= U_t(h_v^t, m_v^{t+1}) \end{aligned} \\
其中,N(v)N(v) 表示圖 G 中節點 v 的鄰居。

讀出階段使用一個讀出函數 R 來計算整張圖的特徵向量:
y^=R(hvTvG) \hat y = R({h_v^T | v \in G}) \\
消息函數 MtM_t,向量更新函數 UtU_t 和讀出函數 RR 都是可微函數。RR 作用於節點的狀態集合,同時對節點的排列不敏感,這樣才能保證 MPNN 對圖同構保持不變。

此外,我們也可以通過引入邊的隱藏層狀態來學習圖中的每一條邊的特徵,並且同樣可以用上面的等式進行學習和更新。

接下來我們看下如何通過定義消息函數更新函數讀出函數來適配不同種模型。

Paper 1 : Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)

這篇論文中消息函數爲:
M(hv,hw,evw)=(hw,evw) M(h_v, h_w,e_{vw}) = (h_w,e_{vw}) \\
其中 (.,.)(.,.) 表示拼接(concat);

節點的更新函數爲:
Ut(hvt,mvt+1)=σ(Htdeg(v)mvt+1) U_t(h_v^t,m_v^{t+1}) = \sigma(H_t^{deg(v)}m_v^{t+1}) \\
其中 σ\sigma 爲 sigmoid 函數,deg(v)deg(v) 表示節點 v 的度,HtvH_t^v 是一個可學習的矩陣,t 爲時間步,N 爲節點度;

讀出函數 R 將先前所有隱藏層的狀態 hvth_v^t 進行連接:
R=f(v,tsoftmax(Wthvt)) R = f(\sum_{v,t}softmax(W_th_v^t)) \\
其中 f 是一個神經網絡,WtW_t 是一個可學習的讀出矩陣。

這種消息傳遞階段可能會存在一些問題,比如說最終的消息向量分別對連通的節點和連通的邊求和 mvt+1=(hwt,evw)m_v^{t+1}=(\sum h_w^t,\sum e_{vw}) 。由此可見,該模型實現的消息傳遞無法識別節點和邊之間的相關性。

Paper 2 : Gated Graph Neural Networks (GG-NN), Li et al. (2016)

這篇論文比較有名,作者後續也是在這個模型的基礎上進行改進的。

GG-NN 使用的消息函數爲:
Mt(hvt,hwt,evw)=Aevwhwt M_t(h_v^t,h_w^t,e_{vw})=A_{e_{vw}}h_w^t \\
其中 AevwA_{e_{vw}}evwe_{vw} 的一個可學習矩陣,每條邊都會對應那麼一個矩陣;

更新函數爲:
Ut(hvt,mvt+1)=GRU(hvt,mvt+1) U_t(h_v^t,m_v^{t+1}) = GRU(h_v^t, m_v^{t+1}) \\
其中 GRUGRU 爲門控制單元(Gate Recurrent Unit)。該工作使用了權值捆綁,所以在每一個時間步 t 下都會使用相同的更新函數;

讀出函數 R 爲:
R=vVσ(i(hv(T)),hv0)    (j(hv(T))) R=\sum_{v\in V} \sigma(i(h_v^{(T)}),h_v^0)\; \odot \; (j(h_v^{(T)})) \\
其中 i 和 j 爲神經網絡,\odot 表示元素相乘。

Paper 3 : Interaction Networks, Battaglia et al. (2016)

這篇論文考慮圖中的節點和圖結構,同時也考慮每個時間步下的節點級的影響。這種情況下更新函數的輸入會多一些 (hv,xv,mv)(h_v,x_v,m_v),其中 $x_v $ 是一個外部向量,表示對頂點 v 的一些外部影響。

這篇論文的消息函數 M(hv,hw,evw)M(h_v,h_w,e_{vw}) 是一個以 (hv,hw,evw)(h_v,h_w,e_{vw}) 爲輸入的神經網絡,節點更新函數 U(hv,xv,mv)U(h_v,x_v,m_v) 是一個以 (hv,xv,mv)(h_v,x_v,m_v) 爲輸入的神經網絡,最終會有一個圖級別的輸出 R=f(vGhvT)R=f(\sum_{v\in G}h_v^T) ,其中 f 是一個神經網絡,輸入是最終的隱藏層狀態的和。在原論文中 T=1T=1

Paper 4 : Molecular Graph Convolutions, Kearnes et al. (2016)

這篇論文與其他 MPNN 稍微有些不同,主要區別在於考慮了邊表示 ev,wte_{v,w}^t,並且在消息傳遞階段會進行更新。

消息傳遞函數用的是節點的消息:
Mt(hvt,hwt,evwt)=evwt M_t(h_v^t,h_w^t,e_{vw}^t)=e_{vw}^t
節點的更新函數爲:
Ut(hvt,mvt+1)=α(W1(α(W0hvt),mvt+1)) U_t(h_v^t,m_v^{t+1}) = \alpha(W_1(\alpha(W_0h_v^t),m_v^{t+1}))
其中 (.,.)(.,.) 表示拼接(concat),α\alpha 爲 ReLU 激活函數,W0,W1W_0,W_1 爲可學習權重矩陣;

邊狀態的更新定義爲:
evwt+1=Ut(evwt,hvt,hwt)=α(W4(α(W2,evwt),α(W3(hvt,hwt)))) \begin{aligned} e_{vw}^{t+1} &= U_t^{'}(e_{vw}^t, h_v^t, h_w^t) \\ &= \alpha(W_4(\alpha (W_2,e_{vw}^t), \alpha(W_3(h_v^t,h_w^t)))) \end{aligned} \\
其中,WiW_i 爲可學習權重矩陣。

Paper 5 : Deep Tensor Neural Networks, Schutt et al. (2017)

消息函數爲:
Mt=tanh(Wfc((Wcfhwt+b1)(Wdfevw+b2))) M_t = tanh(W^{fc}((W^{cf}h_w^t+b_1) \odot(W^{df}e_{vw}+b_2))) \\
其中 Wfc,Wcf,WdfW^{fc},W^{cf},W^{df} 爲矩陣,b1,b2b_1,b_2 爲偏置向量;

更新函數爲:
Ut(hvt,mvt+1)=hvt+mvt+1 U_t(h_v^t,m_v^{t+1}) = h_v^t + m_v^{t+1} \\
讀出函數通過單層隱藏層接受每個節點並且求和後輸出:
R=vNN(hvT) R = \sum_v NN(h_v^T) \\
Paper 6 : Laplacian Based Methods, Bruna et al. (2013); Defferrard et al. (2016); Kipf & Welling (2016)

基於拉普拉斯矩陣的方法將圖像中的卷積運算擴展到網絡圖 G 的鄰接矩陣 A 中。

在 Bruna et al. (2013); Defferrard et al. (2016); 的工作中,消息函數爲:
Mt(hvt,hwt)=Cvwthwt M_t(h_v^t,h_w^t) = C_{vw}^t h_w^t \\
其中,矩陣 CvwtC_{vw}^t 爲拉普拉斯矩陣 L 的特徵向量組成的矩陣;

節點的更新函數爲:
Ut(hvt,mvt+1)=σ(mvt+1) U_t(h_v^t, m_v^{t+1}) = \sigma(m_v^{t+1}) \\
其中,σ\sigma 爲非線性的激活函數,比如說 ReLU。

在 Kipf & Welling (2016) 的工作中,消息函數爲:
Mt(hvt,hwt)=Cvwhwt M_t(h_v^t,h_w^t) = C_{vw} h_w^t \\
其中,Cvw=(deg(v)deg(w))1/2AvwC_{vw} = (deg(v)deg(w))^{-1/2}A_{vw}

節點的更新函數爲:
Uvt(hvt,mvt+1)=ReLU(Wtmvt+1) U_v^t(h_v^t, m_v^{t+1}) = ReLU(W^t m_v^{t+1}) \\
可以看到以上模型都是 MPNN 框架的不同實例,所以作者呼籲大家應該致力於將這一框架應用於某個實際應用,並根據不同情況對關鍵部分進行修改,從而引導模型的改進,這樣才能最大限度的發揮模型的能力。

2.2 MPNN Variants

本節來介紹下作者將 MPNN 框架應用於分子預測領域,提出了 MPNN 的變種,並以 QM9 數據集爲例進行了實驗。

QM9 數據集中的分子大多數由碳氫氧氮等元素組成,並組成了約 134k 個有機分子,可以劃分爲四大類(具體類別不介紹了),任務是根據分子結構預測分子所屬類別。

作者主要是基於 GG-NN 來探索 MPNN 的多種改進方式(不同的消息函數、輸出函數等),之所以用 GG-NN 是因爲這是一個很強的 baseline。

2.2.1 Message Functions

首先來看下消息函數,可以以 GG-NN 中使用的消息函數開始,GG-NN 用的是矩陣乘法:
M(hv,hw,evw)=Aevwhw M(h_v,h_w,e_{vw}) = A_{e_{vw}}h_w \\
爲了兼容邊特徵,作者提出了新的消息函數:
M(hv,hw,evw)=A(evw)hw M(h_v,h_w,e_{vw}) = A(e_{vw})h_w \\
其中,A(evw)A(e_{vw}) 是將邊的向量 evwe_{vw} 映射到 d×d 維矩陣的神經網絡。

矩陣乘法有一個特點,從節點 w 到節點 v 的函數僅與隱藏層狀態 hwh_w 和邊向量 evwe_{vw} 有關,而和隱藏狀態 hvth_v^t 無關。理論上來說,如果節點消息同時依賴於源節點 w 和目標節點 v 的話,網絡的消息通道將會得到更有效的利用。所以也可以嘗試去使用一種消息函數的變種:
mvw=f(hwt,hvt,evw) m_{vw} = f(h_w^t, h_v^t, e_{vw}) \\
其中,f 爲神經網絡。

2.2.2 Virtual Graph Elements

其次看來下消息傳遞,作者探索了兩種不同的消息傳遞方式。

最簡單的修改就是爲沒有連接的節點添加一個虛擬的邊,這樣消息便具有了更長的傳播距離;

此外,作者也嘗試了使用潛在的“主”節點(master node),這個節點可以通過特殊的邊來連接到圖中任意一個節點。主節點充當了一個全局的暫存空間,每個節點都會在消息傳遞過程中通過主節點進行讀取和寫入。同時允許主節點具有自己的節點維度,以及內部更新函數(GRU)的單獨權重。其目的同樣是爲了在傳播階段傳播很長的距離。

2.2.3 Readout Functions

然後來看下讀出函數,作者同樣嘗試了兩種讀出函數:

首先是 GG-NN 中的讀出函數:
R=vVσ(i(hv(T)),hv0)    (j(hv(T))) R=\sum_{v\in V} \sigma(i(h_v^{(T)}),h_v^0)\; \odot \; (j(h_v^{(T)})) \\
此外也考慮 set2set 模型。set2set 模型是專門爲在集合運算而設計的,並且相比簡單累加節點的狀態來說具有更強的表達能力。模型首先通過線性映射將數據映射到元組 (hvt,xv)(h_v^t, x_v) ,並將投影元組作爲輸入 T={(hvT,xv)}T=\{(h_v^T,x_v) \},然後經過 M 步計算後,set2set 模型會生成一個與節點順序無關的 Graph-level 的 embeedding 向量,從而得到我們的輸出向量。

2.2.4 Multiple Towers

最後考慮下 MPNN 的伸縮性。

對一個稠密圖來說,消息傳遞階段的每一個時間步的時間複雜度爲 O(n2d2)O(n^2d^2),其中 n 爲節點數,d 爲向量維度,可以看到時間複雜度還是非常高的。

爲了解決這個問題作者將向量維度 d 拆分成 k 份,就變成了 k 個 d/k 維向量,並在傳播過程中每個子向量分別進行傳播和更新,最後再進行合併。此時的子向量時間複雜度爲 O(n2(d/k)2)O(n^2(d/k)^2),考慮 k 個子向量的時間複雜度爲 O(n2d2/k)O(n^2d^2/k)

2.3 Input Representation

這一節主要介紹 GNN 的輸入。

對於分子來說有很多可以提取的特徵,比如說原子組成、化學鍵等,詳細的特徵列表如下圖所示:

對於鄰接矩陣,作者模型嘗試了三種邊表示形式:

化學圖(Chemical Graph):在不考慮距離的情況下,鄰接矩陣的值是離散的鍵類型:單鍵,雙鍵,三鍵或芳香鍵;

距離分桶(Distance bins):基於矩陣乘法的消息函數的前提假設是邊信息是離散的,因此作者將鍵的距離分爲 10 個 bin,比如說 [2,6] 中均勻劃分 8 個 bin,[0,2] 爲 1 個 bin,[6, +∞] 爲 1 個 bin;

原始距離特徵(Raw distance feature):也可以同時考慮距離和化學鍵的特徵,這時每條邊都有自己的特徵向量,此時鄰接矩陣的每個實例都是一個 5 維向量,第一維是距離,其餘思維是四種不同的化學鍵。

4.Experiment

來看一下實驗結果,以 QM-9 數據集爲例,共包含 130462 個分子,以 MAE 爲評估指標。

下圖爲現有算法和作者改進的算法之間的對比:

下圖爲不考慮空間信息的結果:

下圖爲考慮多塔模型和結果:

5.Conclusion

總結:作者從諸多模型中抽離出了 MPNN 框架,並且通過實驗表明,具有消息函數、更新函數和讀出函數的 MPNN 具有良好的歸納能力,可以用於預測分析特性,優於目前的 Baseline,並且無需進行復雜的特徵工程。此外,實驗結果也揭示了全局主節點和利用 set2set 模型的重要性,多塔模型也使得 MPNN 更具伸縮性,方便應用於大型圖中。

6.Reference

  1. 《Neural Message Passing for Quantum Chemistry》
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章