GNN圖神經網絡詳述-02

本文作爲第2部分,主要根據原始論文介紹幾篇基礎的工作,主要包括GNN,GCN及變體,DCNN,Tree-LSTM,包括模型的詳解和模型訓練,以及模型評價。

1. The Graph Neural Network Model

這篇論文是第一個提出Graph Neural Network模型的論文,它將神經網絡使用在圖結構數據上,並細述了神經網絡模型了結構組成、計算方法、優化算法、流程實現等等。論文後面還對模型的複雜度進行了評估,以及在現實任務上進行了實驗和比較(比較算法爲NL、L、FNN)。該報告暫時主要關注模型設計部分和實驗結果部分,忽略複雜性評估部分

圖領域應用

對於圖領域問題,假設函數τ\tau是將一個圖GG和圖中的一個節點nn轉化爲一個實值向量的函數
τ(G,n)Rm \tau(G,n)\in{R^m} 那麼監督學習的任務就在於從已知樣本中學習得到這樣的函數。

圖領域的應用主要可以分爲兩種類型:專注於圖的應用(graph-focused)和專注於節點的應用(node-focused)。對於graph-focused的應用,函數τ\tau和具體的節點無關,(即τ(G)\tau(G)),訓練時,在一個圖的數據集中進行分類或迴歸。對於node-focused的應用,τ\tau函數依賴於具體的節點nn,即τ(G,n)\tau(G,n),如下:
在這裏插入圖片描述

  • (a) 是一個化學分子結構,能夠使用圖 GG 進行表示,函數τ(G)\tau(G)可能用於估計這種化學分子對人體有害的概率,因此,我們並不關注分子中具體的原子(相當於節點),所以屬於graph-focused應用。
  • (b) 是一張城堡的圖片,圖片中的每一種結構都由節點表示,函數τ(G,n)\tau(G,n)可能用於預測每一個節點是否屬於城堡(圖中的黑點)。這種類型屬於node-focused應用。

GNN模型詳述

GNN模型基於信息傳播機制,每一個節點通過相互交換信息來更新自己的節點狀態,直到達到某一個穩定值,GNN的輸出就是在每個節點處,根據當前節點狀態分別計算輸出。

有如下定義:

  • 一個圖GG表示爲一對(N,E)(\boldsymbol{N}, \boldsymbol{E}),其中,N\boldsymbol{N}表示節點集合,E\boldsymbol{E}表示邊集。

  • ne[n]ne[n]表示節點nn的鄰居節點集合

  • co[n]co[n]表示以nn節點爲頂點的所有邊集合

  • lnRlN\boldsymbol{l}_{n} \in \mathbb{R}^{l_{N}}表示節點nn的特徵向量

  • l(n1,n2)RlE\boldsymbol{l}_{\left(n_{1}, n_{2}\right)} \in \mathbb{R}^{l_{E}}表示邊(n1,n2)(n_1,n_2)的特徵向量

  • l\boldsymbol{l}表示所有特徵向量疊在一起的向量

:原論文裏面l\boldsymbol{l}表示label,但論文中的label指的是features of objects related to nodes and features of the relationships between the objects,也就是相關特徵,所以這裏一律使用特徵向量翻譯。

論文將圖分爲positional graph和nonpositional graph,對於positional graph,對於每一個節點nn,都會給該節點的鄰居節點uu賦予一個position值νn(u)\nu_{n}(u),該函數稱爲injective function,νn:ne[n]{1,N}\nu_{n} : \mathrm{ne}[n] \rightarrow\{1, \ldots|\mathbf{N}|\}

假設存在一個圖-節點對的集合D=G×N\mathcal{D}=\mathcal{G} \times \mathcal{N}G\mathcal{G}表示圖的集合,N\mathcal{N}表示節點集合,圖領域問題可以表示成一個有如下數據集的監督學習框架
L={(Gi,ni,j,ti,j)Gi=(Ni,Ei)G;ni,jNi;ti,jRm,1ip,1jqi} \mathcal{L}=\left\{\left(\boldsymbol{G}_{i}, n_{i, j}, \boldsymbol{t}_{i, j}\right)| \boldsymbol{G}_{i}=\left(\boldsymbol{N}_{i}, \boldsymbol{E}_{i}\right) \in \mathcal{G}\right.;n_{i, j} \in \boldsymbol{N}_{i} ; \boldsymbol{t}_{i, j} \in \mathbb{R}^{m}, 1 \leq i \leq p, 1 \leq j \leq q_{i} \} 其中,ni,jNin_{i, j} \in \boldsymbol{N}_{i}表示集合NiN\boldsymbol{N}_{i} \in \mathcal{N}中的第jj個節點,ti,j\boldsymbol{t}_{i, j}表示節點nijn_{ij}的期望目標(即標籤)。

節點nn的狀態用xnRs\boldsymbol{x}_{n} \in \mathbb{R}^{s}表示,該節點的輸出用on\boldsymbol{o}_{\boldsymbol{n}}表示,fwf_{\boldsymbol{w}}local transition functiongwg_{\boldsymbol{w}}local output function,那麼xn\boldsymbol{x}_{n}on\boldsymbol{o}_{\boldsymbol{n}}的更新方式如下
xn=fw(ln,lco[n],xne[n],lne[n])on=gw(xn,ln) \begin{array}{l}{\boldsymbol{x}_{n}=f_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}, \boldsymbol{l}_{\mathrm{co}[n]}, \boldsymbol{x}_{\mathrm{ne}[n]}, \boldsymbol{l}_{\mathrm{ne}\left[n\right]}\right)} \\ {\boldsymbol{o}_{n}=g_{\boldsymbol{w}}\left(\boldsymbol{x}_{n}, \boldsymbol{l}_{n}\right)}\end{array} 其中,ln,lco[n],xne[n],lne[n]\boldsymbol{l}_{n}, \boldsymbol{l}_{\operatorname{co}[n]}, \boldsymbol{x}_{\mathrm{ne}[n]}, \boldsymbol{l}_{\mathrm{ne}[n]}分別表示節點nn的特徵向量、與節點nn相連的邊的特徵向量、節點nn鄰居節點的狀態向量、節點nn鄰居節點的特徵向量。

假設x,o,l,lN\boldsymbol{x}, \boldsymbol{o}, \boldsymbol{l}, \boldsymbol{l}_{N}分別爲所有的狀態、所有的輸出、所有的特徵向量、所有節點的特徵向量的疊加起來的向量,那麼上面函數可以寫成如下形式
x=Fw(x,l)o=Gw(x,lN) \begin{array}{l}{\boldsymbol{x}=F_{\boldsymbol{w}}(\boldsymbol{x}, \boldsymbol{l})} \\ {\boldsymbol{o}=\boldsymbol{G}_{\boldsymbol{w}}\left(\boldsymbol{x}, \boldsymbol{l}_{\boldsymbol{N}}\right)}\end{array} 其中,FwF_{\boldsymbol{w}}global transition functionGwG_{\boldsymbol{w}}global output function,分別是fwf_{\boldsymbol{w}}gwg_{\boldsymbol{w}}的疊加形式。

根據Banach的不動點理論,假設FwF_{\boldsymbol{w}}是一個壓縮映射函數,那麼式子有唯一不動點解,而且可以通過迭代方式逼近該不動點
x(t+1)=Fw(x(t),l) \boldsymbol{x}(t+1)=F_{\boldsymbol{w}}(\boldsymbol{x}(t), \boldsymbol{l}) 其中,x(t)\boldsymbol{x}(t)表示x\boldsymbol{x}在第tt個迭代時刻的值,對於任意初值,迭代的誤差是以指數速度減小的,使用迭代的形式寫出狀態和輸出的更新表達式爲
xn(t+1)=fw(ln,lco[n],xne[n](t),lne[n])on(t)=gw(xn(t),ln),nN \begin{aligned} \boldsymbol{x}_{n}(t+1) &=f_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}, \boldsymbol{l}_{\mathrm{co}[n]}, \boldsymbol{x}_{\mathrm{ne}[n]}(t), \boldsymbol{l}_{\mathrm{ne}[n]}\right) \\ \boldsymbol{o}_{n}(t) &=g_{\boldsymbol{w}}\left(\boldsymbol{x}_{n}(t), \boldsymbol{l}_{n}\right), \quad n \in \boldsymbol{N} \end{aligned} GNN的信息傳播流圖以及等效的網絡結構如下圖所示

在這裏插入圖片描述根據上圖所示,頂端的圖是原始的Graph,中間的圖表示狀態向量和輸出向量的計算流圖,最下面的圖表示將更新流程迭代TT次,並展開之後得到等效網絡圖。

網絡的學習算法設計

GNN的學習就是估計參數w\boldsymbol{w},使得函數φw\varphi_{\boldsymbol{w}}能夠近似估計訓練集
L={(Gi,ni,j,ti,j)Gi=(Ni,Ei)G;ni,jNi;ti,jRm,1ip,1jqi} \mathcal{L}=\left\{\left(\boldsymbol{G}_{i}, n_{i, j}, \boldsymbol{t}_{i, j}\right)| \boldsymbol{G}_{i}=\left(\boldsymbol{N}_{i}, \boldsymbol{E}_{i}\right) \in \mathcal{G}\right.;n_{i, j} \in \boldsymbol{N}_{i} ; \boldsymbol{t}_{i, j} \in \mathbb{R}^{m}, 1 \leq i \leq p, 1 \leq j \leq q_{i} \} 其中,qiq_i表示在圖GiG_{i}中監督學習的節點個數,對於graph-focused的任務,需要增加一個特殊的節點,該節點用來作爲目標節點,這樣,graph-focused任務和node-focused任務都能統一到節點預測任務上,學習目標可以是最小化如下二次損失函數
ew=i=1pj=1qi(ti,jφw(Gi,ni,j))2 e_{\boldsymbol{w}}=\sum_{i=1}^{p} \sum_{j=1}^{q_{i}}\left(\boldsymbol{t}_{i, j}-\varphi_{\boldsymbol{w}}\left(\boldsymbol{G}_{i}, n_{i, j}\right)\right)^{2} 優化算法基於隨機梯度下降的策略,優化步驟按照如下幾步進行

  • 按照迭代方程迭代TT次得到xn(t)x_{n}(t),此時接近不動點解:x(T)x\boldsymbol{x}(T) \approx \boldsymbol{x}
  • 計算參數權重的梯度ew(T)/w\partial e_{\boldsymbol{w}}(T) / \partial \boldsymbol{w}
  • 使用該梯度來更新權重w\boldsymbol{w}

這裏假設函數FwF_{\boldsymbol{w}}是壓縮映射函數,保證最終能夠收斂到不動點。另外,這裏的梯度的計算使用backpropagation-through-time algorithm

爲了表明前面的方法是可行的,論文接着證明了兩個結論


理論1(可微性):令FwF_{\boldsymbol{w}}GwG_{\boldsymbol{w}}分別是global transition function和global output function,如果Fw(x,l)F_{\boldsymbol{w}}(\boldsymbol{x}, \boldsymbol{l})Gw(x,lN)G_{\boldsymbol{w}}\left(\boldsymbol{x}, \boldsymbol{l}_{\boldsymbol{N}}\right)對於x\boldsymbol{x}w\boldsymbol{w}是連續可微的,那麼φw\varphi_{\boldsymbol{w}}w\boldsymbol{w}也是連續可微的。

理論2(反向傳播):令FwF_{\boldsymbol{w}}GwG_{\boldsymbol{w}}分別是global transition function和global output function,如果Fw(x,l)F_{\boldsymbol{w}}(\boldsymbol{x}, \boldsymbol{l})Gw(x,lN)G_{\boldsymbol{w}}\left(\boldsymbol{x}, \boldsymbol{l}_{\boldsymbol{N}}\right)對於x\boldsymbol{x}w\boldsymbol{w}是連續可微的。令z(t)\boldsymbol{z}(t)定義爲
z(t)=z(t+1)Fwx(x,l)+ewoGwx(x,lN) z(t)=z(t+1) \cdot \frac{\partial F_{w}}{\partial x}(x, l)+\frac{\partial e_{w}}{\partial o} \cdot \frac{\partial G_{w}}{\partial x}\left(x, l_{N}\right)
那麼,序列z(T),z(T1),\boldsymbol{z}(T), \boldsymbol{z}(T-1), \ldots收斂到一個向量,z=limtz(t)z=\lim _{t \rightarrow-\infty} z(t),並且收斂速度爲指數級收斂以及與初值z(T)\boldsymbol{z}(T)無關,另外,還存在
eww=ewoGww(x,lN)+zFww(x,l) \frac{\partial e_{w}}{\partial \boldsymbol{w}}=\frac{\partial e_{\boldsymbol{w}}}{\partial \boldsymbol{o}} \cdot \frac{\partial G_{\boldsymbol{w}}}{\partial \boldsymbol{w}}\left(\boldsymbol{x}, \boldsymbol{l}_{N}\right)+\boldsymbol{z} \cdot \frac{\partial F_{\boldsymbol{w}}}{\partial \boldsymbol{w}}(\boldsymbol{x}, \boldsymbol{l})
其中,x\boldsymbol{x}是GNN的穩定狀態。

算法流程如下
在這裏插入圖片描述
FORWARD用於迭代計算出收斂點,BACKWARD用於計算梯度。

Transition和Output函數實現

在GNN中,函數gwg_{\boldsymbol{w}}不需要滿足特定的約束,直接使用多層前饋神經網絡,對於函數fwf_{\boldsymbol{w}},則需要着重考慮,因爲fwf_{\boldsymbol{w}}需要滿足壓縮映射的條件,而且與不動點計算相關。下面提出兩種神經網絡和不同的策略來滿足這些需求

  1. Linear(nonpositional) GNN

    對於節點nn狀態的計算,將fwf_{\boldsymbol{w}}改成如下形式
    xn=u ne n]hw(ln,l(n,u),xu,lu),nN \boldsymbol{x}_{n}=\sum_{u \in \text { ne } | n ]} h_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}, \boldsymbol{l}_{(n, u)}, \boldsymbol{x}_{u}, \boldsymbol{l}_{u}\right), \quad n \in \boldsymbol{N} 相當於是對節點nn的每一個鄰居節點使用hwh_{\boldsymbol{w}},並將得到的值求和來作爲節點nn的狀態。

    由此,對上式中的函數hwh_{\boldsymbol{w}}按照如下方式實現
    hw(ln,l(n,a),xu,lu)=An,uxu+bn h_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}, \boldsymbol{l}_{(n, \mathfrak{a})}, \boldsymbol{x}_{u}, \boldsymbol{l}_{u}\right) = \boldsymbol{A}_{n, u} \boldsymbol{x}_{u}+\boldsymbol{b}_{n} 其中,向量bnRs\boldsymbol{b}_{n} \in \mathbb{R}^{s},矩陣An,uRs×s\boldsymbol{A}_{n, u} \in \mathbb{R}^{s \times s}定義爲兩個前向神經網絡的輸出。更確切地說,令產生矩陣An,u\boldsymbol{A}_{n, u}的網絡爲transition network,產生向量bn\boldsymbol{b}_{n}的網絡爲forcing network

    transition network表示爲ϕw\phi_{\boldsymbol{w}}
    ϕw:R2lN+lERs2 \phi_{\boldsymbol{w}} : \mathbb{R}^{2 l_{N}+l_{E}} \rightarrow \mathbb{R}^{s^{2}}
    forcing network表示爲ρw\rho_{\boldsymbol{w}}
    ρw:RlNRs \rho_{\boldsymbol{w}} : \mathbb{R}^{l_{N}} \rightarrow \mathbb{R}^{s} 由此,可以定義An,u\boldsymbol{A}_{n, u}bn\boldsymbol{b}_{n}
    An,u=μsne[u]Ξbw=ρw(ln) \begin{aligned} \boldsymbol{A}_{\boldsymbol{n}, \boldsymbol{u}} &=\frac{\mu}{s|\operatorname{ne}[u]|} \cdot \boldsymbol{\Xi} \\ \boldsymbol{b}_{\boldsymbol{w}} &=\rho_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}\right) \end{aligned} 其中,μ(0,1)\mu \in(0,1)Ξ=resize(ϕw(ln,l(n,u),lu))\Xi=\operatorname{resize}\left(\phi_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}, \boldsymbol{l}_{(n, u)}, \boldsymbol{l}_{u}\right)\right)resize()\text{resize}(\cdot)表示將s2s^2維的向量整理(reshape)成s×ss\times{s}的矩陣,也就是說,將transition network的輸出整理成方形矩陣,然後乘以一個係數就得到An,u\boldsymbol{A}_{n, u}bn\boldsymbol{b}_{n}就是forcing network的輸出。

    在這裏,假定ϕw(ln,l(n,u),lu)1s\left\|\phi_{\boldsymbol{w}}\left(\boldsymbol{l}_{n}, \boldsymbol{l}_{(\boldsymbol{n}, \boldsymbol{u})}, \boldsymbol{l}_{u}\right)\right\|_{1} \leq \boldsymbol{s},這個可以通過設定transition function的激活函數來滿足,比如設定激活函數爲tanh()tanh()。在這種情況下,Fw(x,l)=Ax+bF_{\boldsymbol{w}}(\boldsymbol{x}, \boldsymbol{l})=\boldsymbol{A} \boldsymbol{x}+\boldsymbol{b}A\boldsymbol{A}b\boldsymbol{b}分別是An,u\boldsymbol{A}_{n, u}的塊矩陣形式和bn\boldsymbol{b}_{n}的堆疊形式,通過簡單的代數運算可得
    Fwx1=A1maxuN(nne[u]An,u1)maxuN(μsne[u]nne[u]Ξ1)μ \begin{aligned}\left\|\frac{\partial F_{\boldsymbol{w}}}{\partial \boldsymbol{x}}\right\|_{1} &=\|\boldsymbol{A}\|_{1} \leq \max _{u \in \boldsymbol{N}}\left(\sum_{n \in \operatorname{ne}[u]}\left\|\boldsymbol{A}_{n, u}\right\|_{1}\right) \\ & \leq \max _{u \in N}\left(\frac{\mu}{s|\operatorname{ne}[u]|} \cdot \sum_{n \in \mathrm{ne}[u]}\|\mathbf{\Xi}\|_{1}\right) \leq \mu \end{aligned}
    該式表示FwF_{\boldsymbol{w}}對於任意的參數w\boldsymbol{w}是一個壓縮映射。

    矩陣MM的1-norm定義爲
    M1=maxjimi,j \|M\|_{1}=\max _{j} \sum_{i}\left|m_{i, j}\right|

  2. Nonelinear(nonpositional) GNN:在這個結構中,hwh_{\boldsymbol{w}}通過多層前饋網絡實現,但是,並不是所有的參數w\boldsymbol{w}都會被使用,因爲同樣需要保證FwF_{\boldsymbol{w}}是一個壓縮映射函數,這個可以通過懲罰項來實現
    ew=i=1pj=1qi(ti,jφw(Gi,ni,j))2+βL(Fwx) e_{\boldsymbol{w}}=\sum_{i=1}^{p} \sum_{j=1}^{q_{i}}\left(\boldsymbol{t}_{i, j}-\varphi_{\boldsymbol{w}}\left(\boldsymbol{G}_{i}, n_{i, j}\right)\right)^{2}+\beta L\left(\left\|\frac{\partial F_{\boldsymbol{w}}}{\partial \boldsymbol{x}}\right\|\right) 其中,懲罰項L(y)L(y)y>μy>\mu時爲(yμ)2(y-\mu)^2,在yμy\le{\mu}時爲0,參數μ(0,1)\mu\in(0,1)定義爲希望的FwF_{\boldsymbol{w}}的壓縮係數。

實驗結果

論文將GNN模型在三個任務上進行了實驗:子圖匹配(subgraph matching)任務,誘變(mutagenesis)任務和網頁排序(web page ranking)任務。在這些任務上使用linear和nonlinear的模型測試,其中nonlinear模型中的激活函數使用sigmoid函數。

子圖匹配任務爲在一個大圖GG上找到給定的子圖SS(標記出屬於子圖的節點),也就是說,函數τ\tau必須學習到,如果ni,jn_{i,j}屬於子圖GG,那麼τ(Gi,ni,j)=1\tau(G_i,n_{i,j})=1,否則,τ(Gi,ni,j)=1\tau(G_i,n_{i,j})=-1。實驗結果中,nonlinear模型的效果要好於linear模型的效果,兩個模型都要比FNN模型效果更好。

誘變問題任務是對化學分子進行分類,識別出誘變化合物,採用二分類方法。實驗結果是nonlinear效果較好,但不是最好。

網頁排序任務是學會網頁排序。實驗表明雖然訓練集只包含50個網頁,但是仍然沒有產生過擬合的現象。

模型實現

在模擬的節點分類任務上實現該論文的GNN模型。

  • 任務要求爲輸入一個graph,該graph的所有節點都有標籤,然後對部分節點進行訓練,在驗證階段使用另外一部分節點進行驗證。輸入的數據圖如下圖:

在這裏插入圖片描述
其中,該graph總共有18個節點,分別是{n1,n2,...,n18}\{n1,n2,...,n18\},不同顏色的節點表示不同的節點類別。模擬的問題中節點類別有三類,分別用{0,1,2}\{0,1,2\}表示,在訓練階段,使用節點{n1,n2,n3,n7,n8,n9,n13,n14,n15}\{n1,n2,n3,n7,n8,n9,n13,n14,n15\}進行訓練,相當於每一類取出三個節點訓練,其餘的節點用於在驗證階段進行驗證。

輸入的數據爲(node,label)列表和(node1, node2)邊列表,表示如下

# (node, label)集
N = [("n{}".format(i), 0) for i in range(1,7)] + \
    [("n{}".format(i), 1) for i in range(7,13)] + \
    [("n{}".format(i), 2) for i in range(13,19)]
# 邊集
E = [("n1","n2"), ("n1","n3"), ("n1","n5"),
     ("n2","n4"),
     ("n3","n6"), ("n3","n9"),
     ("n4","n5"), ("n4","n6"), ("n4","n8"),
     ("n5","n14"),
     ("n7","n8"), ("n7","n9"), ("n7","n11"),
     ("n8","n10"), ("n8","n11"), ("n8", "n12"),
     ("n9","n10"), ("n9","n14"),
     ("n10","n12"),
     ("n11","n18"),
     ("n13","n15"), ("n13","n16"), ("n13","n18"),
     ("n14","n16"), ("n14","n18"),
     ("n15","n16"), ("n15","n18"),
     ("n17","n18")]

NN爲節點集合,EE爲邊集合。

  • 模型部分使用論文的linear函數來設計fwf_wgwg_w,而且,這兩個函數在graph所有的節點上進行共享。模型部分實現了Ξ\Xiρ\rho函數,以及完整的forward傳播部分,如下簡化代碼:

    # 實現Xi函數,輸入一個batch的相鄰節點特徵向量對ln,返回是s*s的A矩陣
    # ln是特徵向量維度,s爲狀態向量維度
    # Input : (N, 2*ln)
    # Output : (N, S, S)
    class Xi(nn.Module):
        def __init__(self, ln, s):
            ...
        def forward(self, X):
            ...
    
    # 實現Rou函數
    # Input : (N, ln)
    # Output : (N, S)
    class Rou(nn.Module):
        def __init__(self, ln, s):
            ...
        def forward(self, X):
            ...
    
    # 實現Hw函數
    # Input : (N, 2 * ln) 
    #         每一行都是一個節點特徵向量和該節點的某一個鄰接向量concat
    #         得到的向量
    # Input : (N, s)
    #         對應中心節點的狀態向量
    # Input : (N, )
    #         對應中心節點的度的向量
    # Output : (N, s)
    class Hw(nn.Module):
        def __init__(self, ln, s, mu=0.9):
            ...
        def forward(self, X, H, dg_list):
            ...
    
    class AggrSum(nn.Module):
        def __init__(self, node_num):
            ...
        
        def forward(self, H, X_node):
            ...
    
    # 實現GNN模型
    class OriLinearGNN(nn.Module):
        def __init__(self, node_num, feat_dim, stat_dim, T):
            ...
        # Input : 
        #    X_Node : (N, )
        #    X_Neis : (N, )
        #    H      : (N, s)
        #    dg_list: (N, )
        def forward(self, X_Node, X_Neis, dg_list):
            ...
            for t in range(self.T):
                # (V, s) -> (N, s)
                H = torch.index_select(self.node_states, 0, X_Node)
                # (N, s) -> (N, s)
                H = self.Hw(X, H, dg_list)
                # (N, s) -> (V, s)
                self.node_states = self.Aggr(H, X_Node)
    #             print(H[1])
            ...
    

    可以看出,在模型訓練階段,每次forward,都會直接循環計算T次fwf_w函數計算不動點,然後再計算output。

  • 模型訓練部分按照常規的分類模型進行訓練,採用Adam優化器,學習率保持爲0.01,權重衰減爲0.01,使用交叉熵作爲損失函數,模型訓練部分代碼如下

    # 用於計算accuracy
    def CalAccuracy(output, label):
        ...
    
    # 開始訓練模型
    def train(node_list, edge_list, label_list, T, ndict_path="./node_dict.json"):
        # 生成node-index字典
        ...
    
        # 現在需要生成兩個向量
        # 第一個向量類似於
        #   [0, 0, 0, 1, 1, ..., 18, 18]
        # 其中的值表示節點的索引,連續相同索引的個數爲該節點的度
        # 第二個向量類似於
        #   [1, 2, 4, 1, 4, ..., 11, 13]
        # 與第一個向量一一對應,表示第一個向量節點的鄰居節點
    
        # 首先統計得到節點的度
        ...
        
        # 然後生成兩個向量
        ...
        # 生成度向量
        ...
        # 準備訓練集和測試集
        train_node_list = [0,1,2,6,7,8,12,13,14]
        train_node_label = [0,0,0,1,1,1,2,2,2]
        test_node_list = [3,4,5,9,10,11,15,16,17]
        test_node_label = [0,0,0,1,1,1,2,2,2]
        
        # 開始訓練
        model = OriLinearGNN(node_num=len(node_list),
                             feat_dim=2,
                             stat_dim=2,
                             T=T)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.01)
        criterion = nn.CrossEntropyLoss(size_average=True)
        
        min_loss = float('inf')
        node_inds_tensor = Variable(torch.Tensor(node_inds).long())
        node_neis_tensor = Variable(torch.Tensor(node_neis).long())
        train_label = Variable(torch.Tensor(train_node_label).long())
        for ep in range(500):
            # 運行模型得到結果
            res = model(node_inds_tensor, node_neis_tensor, dg_list) # (V, 3)
            train_res = torch.index_select(res, 0, torch.Tensor(train_node_list).long())
            test_res = torch.index_select(res, 0, torch.Tensor(test_node_list).long())
            loss = criterion(input=train_res,
                             target=train_label)
            loss_val = loss.item()
            train_acc = CalAccuracy(train_res.cpu().detach().numpy(), np.array(train_node_label))
            test_acc = CalAccuracy(test_res.cpu().detach().numpy(), np.array(test_node_label))
            # 更新梯度
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
    
            if loss_val < min_loss:
                min_loss = loss_val
            print("==> [Epoch {}] : loss {:.4f}, min_loss {:.4f}, train_acc {:.3f}, test_acc {:.3f}".format(ep, loss_val, min_loss, train_acc, test_acc))
    
  • 模型的訓練和評估結果如下圖:
    在這裏插入圖片描述
    第一條曲線爲訓練loss曲線,第二條曲線爲訓練的acc曲線,第三條曲線爲評估的acc曲線,可以看出,訓練的loss很快達到了最低0.56左右,而準確率達到了1.0,基本已經過擬合,而驗證集的準確率一直很低,最高在第500個epoch處上升到0.667,約爲23\frac{2}{3}左右。

2. Graph Convolutional Networks

圖卷積的演變

按照圖傅里葉變換的性質,可以得到如下圖卷積的定義
(fh)G=Φdiag[h^(λ1),,h^(λn)]ΦTf (\boldsymbol{f} * \boldsymbol{h})_{\mathcal{G}}=\boldsymbol{\Phi} \operatorname{diag}\left[\hat{h}\left(\lambda_{1}\right), \ldots, \hat{h}\left(\lambda_{n}\right)\right] \mathbf{\Phi}^{T} \boldsymbol{f} 其中

  • 對於圖$ \boldsymbol{f}的傅里葉變換爲\boldsymbol{\hat{f}}=\mathbf{\Phi}^{T} \boldsymbol{f}$

  • 對於卷積核的圖傅里葉變換:h^=(h^1,,h^n)\hat{h}=\left(\hat{h}_{1}, \ldots, \hat{h}_{n}\right),其中
    h^k=h,ϕk,k=1,2,n \hat{h}_{k}=\left\langle h, \phi_{k}\right\rangle, k=1,2 \ldots, n 按照矩陣形式就是h^=ΦTh\hat{\boldsymbol{h}}=\mathbf{\Phi}^{T} \boldsymbol{h}

  • 對兩者的傅里葉變換向量f^RN×1\hat{f} \in \mathbb{R}^{N \times 1}h^RN×1\hat{h} \in \mathbb{R}^{N \times 1}求element-wise乘積,等價於將h\boldsymbol{h}組織成對角矩陣,即diag[h^(λk)]RN×N\operatorname{diag}\left[\hat{h}\left(\lambda_{k}\right)\right] \in \mathbb{R}^{N \times N},然後再求diag[h^(λk)]\operatorname{diag}\left[\hat{h}\left(\lambda_{k}\right)\right]f\boldsymbol{f}矩陣乘法。

  • 求上述結果的傅里葉逆變換,即左乘Φ\mathbf{\Phi}

深度學習中的卷積就是要設計trainable的卷積核,從上式可以看出,就是要設計diag[h^(λ1),,h^(λn)]\operatorname{diag}\left[\hat{h}\left(\lambda_{1}\right), \ldots, \hat{h}\left(\lambda_{n}\right)\right],由此,可以直接將其變爲卷積核diag[θ1,,θn]\operatorname{diag}\left[\theta_{1}, \ldots, \theta_{n}\right],而不需要再將卷積核進行傅里葉變換,由此,相當於直接將變換後的參量進行學習。

第一代GCN

第一代GCN爲
youtput=σ(ΦgθΦTx)=σ(Φdiag[θ1,,θn]ΦTx) \boldsymbol{y}_{\text {output}}=\sigma\left(\mathbf{\Phi} \boldsymbol{g}_{\theta} \mathbf{\Phi}^{T} \boldsymbol{x}\right)=\sigma\left(\boldsymbol{\Phi} \operatorname{diag}\left[\theta_{1}, \ldots, \theta_{n}\right] \mathbf{\Phi}^{T} \boldsymbol{x}\right)
其中,x\boldsymbol{x}就是graph上對應每個節點的feature構成的向量,x=(x1,x2,,xn)x=\left(x_{1}, x_{2}, \ldots, x_{n}\right),這裏暫時對每個節點都使用標量,然後經過激活之後,得到輸出youtput\boldsymbol{y}_{\text {output}},之後傳入下一層。

一些缺點:

  • 需要對拉普拉斯矩陣進行譜分解來求Φ\mathbf{\Phi},在graph很大的時候複雜度很高。另外,還需要計算矩陣乘積,複雜度爲O(n2)O(n^2)
  • 卷積核參數爲nn,當graph很大的時候,nn會很大。
  • 卷積核的spatial localization不好。

第二代GCN

圖傅里葉變換是關於特徵值(相當於普通傅里葉變換的頻率)的函數,也就是F(λ1),,F(λn)F\left(\lambda_{1}\right), \ldots, F\left(\lambda_{n}\right),即F(Λ)F(\mathbf{\Lambda}),因此,將卷積核gθ\boldsymbol{g}_{\theta}寫成gθ(Λ)\boldsymbol{g}_{\theta}(\Lambda),然後,將gθ(Λ)\boldsymbol{g}_{\theta}(\Lambda)定義爲如下k階多項式
gθ(Λ)k=0KθkΛk g_{\theta^{\prime}}(\mathbf{\Lambda}) \approx \sum_{k=0}^{K} \theta_{k}^{\prime} \mathbf{\Lambda}^{k} 將卷積公式帶入,可以得到
gθxΦk=0KθkΛkΦTx=k=0Kθk(ΦΛkΦT)x=k=0Kθk(ΦΛΦT)kx=k=0KθkLkx \begin{aligned} g_{\theta^{\prime}} * x & \approx \Phi \sum_{k=0}^{K} \theta_{k}^{\prime} \mathbf{\Lambda}^{k} \mathbf{\Phi}^{T} \boldsymbol{x} \\ &=\sum_{k=0}^{K} \theta_{k}^{\prime}\left(\mathbf{\Phi} \mathbf{\Lambda}^{k} \mathbf{\Phi}^{T}\right) x \\ &=\sum_{k=0}^{K} \theta_{k}^{\prime}\left(\mathbf{\Phi} \mathbf{\Lambda} \mathbf{\Phi}^{T}\right)^{k} x \\ &=\sum_{k=0}^{K} \theta_{k}^{\prime} \boldsymbol{L}^{k} x \end{aligned}
可以看出,這一代的GCN不需要做特徵分解了,可以直接對Laplacian矩陣做變換,通過事先將Laplacian矩陣求出來,以及Lk\boldsymbol{L}^{k}求出來,前向傳播的時候,就可以直接使用,複雜度爲O(Kn2)O(Kn^2)

對於每一次Laplacian矩陣L\boldsymbol{L}x\mathbf{x}相乘,對於節點nn,相當於從鄰居節點ne[n]ne[n]傳遞一次信息給節點nn,由於連續乘以了kk次Laplacian矩陣,那麼相當於n節點的k-hop之內的節點能夠傳遞信息給nn因此,實際上只利用了節點的K-Localized信息

另外,可以使用切比雪夫展開式來近似Lk\boldsymbol{L}^{k}任何k次多項式都可以使用切比雪夫展開式來近似,由此,引入切比雪夫多項式的KK階截斷獲得Lk\boldsymbol{L}^{k}近似,從而獲得對gθ(Λ)g_{\theta}(\mathbf{\Lambda})的近似
gθ(Λ)k=0KθkTk(Λ~) g_{\theta^{\prime}}(\mathbf{\Lambda}) \approx \sum_{k=0}^{K} \theta_{k}^{\prime} T_{k}(\tilde{\mathbf{\Lambda}})
其中,Λ~=2λmaxΛIn\tilde{\mathbf{\Lambda}}=\frac{2}{\lambda_{\max }} \mathbf{\Lambda}-\boldsymbol{I}_{n}θRK\boldsymbol{\theta}^{\prime} \in \mathbb{R}^{K}爲切比雪夫向量,θk\theta_{k}^{\prime}爲第kk個分量,切比雪夫多項式Tk(x)T_{k}(x)使用遞歸的方式進行定義:Tk(x)=2xTk1(x)Tk2(x)T_{k}(x)=2 x T_{k-1}(x)-T_{k-2}(x),其中,T0(x)=1,T1(x)=xT_{0}(x)=1, T_{1}(x)=x

此時,帶入到卷積公式
gθxΦk=0KθkTk(Λ~)ΦTxk=0Kθk(ΦTk(Λ~)ΦT)x=k=0KθkTk(L~)x \begin{aligned} \boldsymbol{g}_{\boldsymbol{\theta}^{\prime}} * \boldsymbol{x} & \approx \mathbf{\Phi} \sum_{k=0}^{K} \theta_{k}^{\prime} T_{k}(\tilde{\boldsymbol{\Lambda}}) \mathbf{\Phi}^{T} \boldsymbol{x} \\ &\approx \sum_{k=0}^{K} \theta_{k}^{\prime}\left(\mathbf{\Phi} T_{k}(\tilde{\mathbf{\Lambda}}) \mathbf{\Phi}^{T}\right) x \\ &=\sum_{k=0}^{K} \theta_{k}^{\prime} T_{k}(\tilde{\boldsymbol{L}}) \boldsymbol{x} \end{aligned} 其中,L~=2λmaxLIn\tilde{\boldsymbol{L}}=\frac{2}{\lambda_{\max }} \boldsymbol{L}-\boldsymbol{I}_{n}

因此,可以得到輸出爲
youtput=σ(k=0KθkTk(L~)x) \boldsymbol{y}_{\text {output}}=\sigma\left(\sum_{k=0}^{K} \theta_{k}^{\prime} T_{k}(\tilde{\boldsymbol{L}}) \boldsymbol{x}\right)

第三代GCN

這一代GCN直接取切比雪夫多項式中K=1K=1,此時模型是1階近似

K=1K=1λmax=2\lambda_{\max }=2帶入可以得到
gθxθ0x+θ1(LIn)x=θ0x+θ1(LIn)x=θ0xθ1(D1/2WD1/2)x \begin{aligned} \boldsymbol{g}_{\boldsymbol{\theta}^{\prime}} * \boldsymbol{x} & \approx \boldsymbol{\theta}_{0}^{\prime} \boldsymbol{x}+\theta_{1}^{\prime}\left(\boldsymbol{L}-\boldsymbol{I}_{n}\right) \boldsymbol{x} \\ &=\boldsymbol{\theta}_{0}^{\prime} \boldsymbol{x}+\theta_{1}^{\prime}\left(\boldsymbol{L}-\boldsymbol{I}_{n}\right) \boldsymbol{x} \\ &=\theta_{0}^{\prime} \boldsymbol{x}-\theta_{1}^{\prime}\left(\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2}\right) \boldsymbol{x} \end{aligned}
其中,歸一化拉普拉斯矩陣L=D1/2(DW)D1/2=InD1/2WD1/2\boldsymbol{L}=\boldsymbol{D}^{-1 / 2}(\boldsymbol{D}-\boldsymbol{W}) \boldsymbol{D}^{-1 / 2}=\boldsymbol{I}_{n}-\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2}爲了進一步簡化,令θ0=θ1\theta_{0}^{\prime}=-\theta_{1}^{\prime},此時只含有一個參數θ\theta
gθx=θ(In+D1/2WD1/2)x g_{\theta^{\prime}} * x=\theta\left(I_{n}+D^{-1 / 2} W D^{-1 / 2}\right) x
由於In+D1/2WD1/2\boldsymbol{I}_{n}+\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2}譜半徑[0,2][0,2]太大,使用歸一化的trick
In+D1/2WD1/2D~1/2W~D~1/2 \boldsymbol{I}_{n}+\boldsymbol{D}^{-1 / 2} \boldsymbol{W} \boldsymbol{D}^{-1 / 2} \rightarrow \tilde{\boldsymbol{D}}^{-1 / 2} \tilde{\boldsymbol{W}} \tilde{\boldsymbol{D}}^{-1 / 2}
其中,W~=W+In\tilde{\boldsymbol{W}}=\boldsymbol{W}+\boldsymbol{I}_{n}D~ij=ΣjW~ij\tilde{D}_{i j}=\Sigma_{j} \tilde{W}_{i j}

由此,帶入卷積公式
gθxRn×1=θ(D~1/2W~D~1/2Rn×n)xRn×1 \underbrace{g_{\theta^{\prime}} * x}_{\mathbb{R}^{n \times 1}}=\theta\left(\underbrace{\tilde{D}^{-1 / 2} \tilde{W} \tilde{D}^{-1 / 2}}_{\mathbb{R}^{n \times n}}\right) \underbrace{x}_{\mathbb{R}^{n \times 1}}
如果推廣到多通道,相當於每一個節點的信息是向量
xRN×1XRN×C x \in \mathbb{R}^{N \times 1} \rightarrow X \in \mathbb{R}^{N \times C}
其中,NN是節點數量CC是通道數,或者稱作表示節點的信息維度數X\mathbf{X}是節點的特徵矩陣

相應的卷積核參數變化
θRΘRC×F \theta \in \mathbb{R} \rightarrow \Theta \in \mathbb{R}^{C \times F}
其中,FF爲卷積核數量。

那麼卷積結果寫成矩陣形式爲
ZRN×F=D~1/2W~D~1/2RN×NXRN×CΘRC×F \underbrace{Z}_{\mathbb{R}^{N \times F}}=\underbrace{\tilde{D}^{-1 / 2} \tilde{W} \tilde{D}^{-1 / 2}}_{\mathbb{R}^{N \times N}} \underbrace{X}_{\mathbb{R}^{N \times C}} \underbrace{\mathbf{\Theta}}_{\mathbb{R}^{C \times F}}
上述操作可以疊加多層,對上述輸出激活一下,就可以作爲下一層節點的特徵矩陣

這一代GCN特點:

  • K=1K=1,相當於直接取鄰域信息,類似於3×33\times{3}的卷積核。
  • 由於卷積核寬度減小,可以通過增加捲積層數來擴大感受野,從而增強網絡的表達能力。
  • 增加了參數約束,比如λmax2\lambda_{\max } \approx 2,引入歸一化操作。

論文模型

論文采用兩層的GCN,用來在graph上進行半監督的節點分類任務,鄰接矩陣爲AA,首先計算出A^=D~12A~D~12\hat{A}=\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}},由此,前向網絡模型形式如下
Z=f(X,A)=softmax(A^ReLU(A^XW(0))W(1)) Z=f(X, A)=\operatorname{softmax}\left(\hat{A} \operatorname{ReLU}\left(\hat{A} X W^{(0)}\right) W^{(1)}\right)
其中,W(0)RC×HW^{(0)} \in \mathbb{R}^{C \times H}爲輸入層到隱藏層的權重矩陣,隱藏層的特徵維度爲HHW(1)RH×FW^{(1)} \in \mathbb{R}^{H \times F}爲隱藏層到輸出層的權重矩陣,softmax激活函數定義爲softmax(xi)=1Zexp(xi)\operatorname{softmax}\left(x_{i}\right)=\frac{1}{\mathcal{Z}} \exp \left(x_{i}\right)Z=iexp(xi)\mathcal{Z}=\sum_{i} \exp \left(x_{i}\right),相當於對每一列做softmax,由此,得到交叉熵損失函數
L=lYLf=1FYlflnZlf \mathcal{L}=-\sum_{l \in \mathcal{Y}_{L}} \sum_{f=1}^{F} Y_{l f} \ln Z_{l f}
其中,YL\mathcal{Y}_{L}爲帶有標籤的節點集合。

在這裏插入圖片描述
上圖中左圖爲GCN圖示,輸入爲CC個通道,輸出爲FF個通道,Y1Y_1Y2Y_2爲節點標籤。右圖爲在一數據集上進行訓練得到的隱藏層激活值經過t-SNE降維可視化後的結果,可以看出聚類效果較好。

實驗結果

論文在如下幾個任務中進行實驗

  • 在citation network中進行半監督的document classification。
  • 在從knowledge graph中提取的bipartite graph中進行半監督的entity classification

實驗數據說明如下
在這裏插入圖片描述
前三個Dataset是citation network數據集,節點表示文檔,邊表示引用的連接,label rate表示用來有監督訓練的節點數量佔總節點數量比例,第四個Dataset是bipartite graph數據集。

結果如下:
在這裏插入圖片描述

可以看出,在比較的幾種算法中,論文GCN的在準確率和時間上都最好。

3. DCNN

該模型對每一個節點(或邊、或圖)採用H個hop的矩陣進行表示,每一個hop都表示該鄰近範圍的鄰近信息,由此,對於局部信息的獲取效果比較好,得到的節點的representation的表示能力很強。

DCNN模型詳述

假設有如下定義

  • 一個graph數據集G={Gtt1T}\mathcal{G}=\left\{G_{t} | t \in 1 \ldots T\right\}
  • graph定義爲Gt=(Vt,Et)G_{t}=\left(V_{t}, E_{t}\right),其中,VtV_t爲節點集合,EtE_t爲邊集合
  • 所有節點的特徵矩陣定義爲XtX_t,大小爲Nt×FN_t\times{F},其中,NtN_t爲圖GtG_t的節點個數,FF爲節點特徵維度
  • 邊信息EtE_t定義爲Nt×NtN_t\times{}N_t的鄰接矩陣AtA_t,由此可以計算出節點度(degree)歸一化的轉移概率矩陣PtP_t,表示從ii節點轉移到jj節點的概率。

對於graph來說沒有任何限制,graph可以是帶權重的或不帶權重的,有向的或無向的。

模型的目標爲預測YY,也就是預測每一個圖的節點標籤,或者邊的標籤,或者每一個圖的標籤,在每一種情況中,模型輸入部分帶有標籤的數據集合,然後預測剩下的數據的標籤

DCNN模型輸入圖G\mathcal{G},返回硬分類預測值YY或者條件分佈概率P(YX)\mathbb{P}(Y|X)。該模型將每一個預測的目標對象(節點、邊或圖)轉化爲一個diffusion-convolutional representation,大小爲H×FH\times{}FHH表示擴散的hops。因此,對於節點分類任務,圖tt的confusion-convolutional representation爲大小爲Nt×H×FN_t\times{H}\times{F}的張量,表示爲ZtZ_t,對於圖分類任務,張量ZtZ_t爲大小爲H×FH\times{F}的矩陣,對於邊分類任務,張量ZtZ_t爲大小爲Mt×H×FM_t\times{H}\times{F}的矩陣。示意圖如下

在這裏插入圖片描述
對於節點分類任務,假設PtP_t^*PtP_t的power series,大小爲Nt×H×NtN_t\times{H}\times{N_t},那麼對於圖tt的節點ii,第jj個hop,第kk維特徵值ZtijkZ_{tijk}計算公式爲
Ztijk=f(Wjkcl=1NtPtijlXtlk) Z_{t i j k}=f\left(W_{j k}^{c} \cdot \sum_{l=1}^{N_{t}} P_{t i j l}^{*} X_{t l k}\right)
使用矩陣表示爲
Zt=f(WcPtXt) Z_{t}=f\left(W^{c} \odot P_{t}^{*} X_{t}\right)
其中\odot表示element-wise multiplication,由於模型只考慮HH跳的參數,即參數量爲O(H×F)O(H\times{F})使得diffusion-convolutional representation不受輸入大小的限制

在計算出ZZ之後,過一層全連接得到輸出YY,使用Y^\hat{Y}表示硬分類預測結果,使用P(YX)\mathbb{P}(Y|X)表示預測概率,計算方式如下
Y^=argmax(f(WdZ)) \hat{Y}=\arg \max \left(f\left(W^{d} \odot Z\right)\right)

P(YX)=softmax(f(WdZ)) \mathbb{P}(Y | X)=\operatorname{softmax}\left(f\left(W^{d} \odot Z\right)\right)

對於圖分類任務,直接採用所有節點表示的均值作爲graph的representation
Zt=f(Wc1NtTPtXt/Nt) Z_{t}=f\left(W^{c} \odot 1_{N_{t}}^{T} P_{t}^{*} X_{t} / N_{t}\right)
其中,1Nt1_{N_t}是全爲1的Nt×1N_t\times{1}的向量。

對於邊分類任務,通過將每一條邊轉化爲一個節點來進行訓練和預測,這個節點與原來的邊對應的首尾節點相連,轉化後的圖的鄰接矩陣AtA_t'可以直接從原來的鄰接矩陣AtA_t增加一個incidence matrix得到
At=(AtBtTBt0) A_{t}^{\prime}=\left(\begin{array}{cc}{A_{t}} & {B_{t}^{T}} \\ {B_{t}} & {0}\end{array}\right)
之後,使用AtA_t'來計算PtP_t',並用來替換PtP_t來進行分類。

對於模型訓練,使用梯度下降法,並採用early-stop方式得到最終模型。

實驗結果

節點分類任務的實驗數據集使用Cora和Pubmed數據集,包含scientific papers(相當於node)、citations(相當於edge)和subjects(相當於label)。實驗評估標準使用分類準確率以及F1值。

節點分類的實驗結果如下

在這裏插入圖片描述

可以看出使用各種評估標準,DCNN效果都是最好的。

圖分類任務的實驗結果如下

在這裏插入圖片描述
可以看出,在不同數據集上,DCNN在圖分類任務上並沒有明顯表現出很好的效果。

優缺點

優點

  • 節點分類準確率很高
  • 靈活性
  • 快速

缺點

  • 內存佔用大:DCNN建立在密集的張量計算上,需要存儲大量的張量,需要O(Nt2H)O(N_t^2H)的空間複雜度。
  • 長距離信息傳播不足:模型對於局部的信息獲取較好,但是遠距離的信息傳播不足。

4. Tree-LSTM

序列型的LSTM模型擴展到樹型的LSTM模型,簡稱Tree-LSTM,並根據孩子節點是否有序,論文提出了兩個模型變體,Child-Sum Tree-LSTM模型和N-ary Tree-LSTM模型。和序列型的LSTM模型的主要不同點在於,序列型的LSTM從前一時刻獲取隱藏狀態hth_t,而樹型的LSTM從其所有的孩子節點獲取隱藏狀態。

模型詳解

Tree-LSTM模型對於每一個孩子節點都會產生一個“遺忘門”fjkf_{jk},這個使得模型能夠從所有的孩子節點選擇性地獲取信息和結合信息

Child-Sum Tree-LSTMs

該模型的更新方程如下
h~j=kC(j)hkij=σ(W(i)xj+U(i)h~j+b(i))fjk=σ(W(f)xj+U(f)hk+b(f))oj=σ(W(o)xj+U(o)h~j+b(o))uj=tanh(W(u)xj+U(u)h~j+b(u))cj=ijuj+kC(j)fjkckhj=ojtanh(cj) \begin{aligned} \tilde{h}_{j} &=\sum_{k \in C(j)} h_{k} \\ i_{j} &=\sigma\left(W^{(i)} x_{j}+U^{(i)} \tilde{h}_{j}+b^{(i)}\right) \\ f_{j k} &=\sigma\left(W^{(f)} x_{j}+U^{(f)} h_{k}+b^{(f)}\right) \\ o_{j} &=\sigma\left(W^{(o)} x_{j}+U^{(o)} \tilde{h}_{j}+b^{(o)}\right) \\ u_{j} &=\tanh \left(W^{(u)} x_{j}+U^{(u)} \tilde{h}_{j}+b^{(u)}\right) \\ c_{j} &=i_{j} \odot u_{j}+\sum_{k \in C(j)} f_{j k} \odot c_{k} \\ h_{j} &=o_{j} \odot \tanh \left(c_{j}\right) \end{aligned} 其中,C(j)C(j)表示jj節點的鄰居節點的個數,hkh_k表示節點kk的隱藏狀態,iji_j表示節點jj的”輸入門“,fjkf_{jk}表示節點jj的鄰居節點kk的“遺忘門“,ojo_j表示節點jj的”輸出門“。

這裏的關鍵點在於第三個公式的fjkf_{jk},這個模型對節點jj的每個鄰居節點kk都計算了對應的”遺忘門“向量,然後在第六行中計算cjc_j時對鄰居節點的信息進行”遺忘“和組合。

由於該模型是對所有的孩子節點求和,所以這個模型對於節點順序不敏感的,適合於孩子節點無序的情況。

N-ary Tree-LSTMs

假如一個樹的最大分支數爲NN(即孩子節點最多爲NN個),而且孩子節點是有序的,對於節點jj,對於該節點的第kk個孩子節點的隱藏狀態和記憶單元分別用hjkh_{jk}cjkc_{jk}表示。模型的方程如下
ij=σ(W(i)xj+=1NU(i)hj+b(i))fjk=σ(W(f)xj+=1NUk(f)hj+b(f))oj=σ(W(o)xj+=1NU(o)hj+b(o))uj=tanh(W(u)xj+=1NU(a)hj+b(a))cj=ijuj+=1Nfjcjhj=ojtanh(cj) \begin{aligned} i_{j} &=\sigma\left(W^{(i)} x_{j}+\sum_{\ell=1}^{N} U_{\ell}^{(i)} h_{j \ell}+b^{(i)}\right) \\ f_{j k} &=\sigma\left(W^{(f)} x_{j}+\sum_{\ell=1}^{N} U_{k \ell}^{(f)} h_{j \ell}+b^{(f)}\right) \\ o_{j} &=\sigma\left(W^{(o)} x_{j}+\sum_{\ell=1}^{N} U_{\ell}^{(o)} h_{j \ell}+b^{(o)}\right) \\ u_{j} &=\tanh \left(W^{(u)} x_{j}+\sum_{\ell=1}^{N} U_{\ell}^{(a)} h_{j \ell}+b^{(a)}\right) \\ c_{j} &=i_{j} \odot u_{j}+\sum_{\ell=1}^{N} f_{j \ell} \odot c_{j \ell} \\ h_{j} &=o_{j} \odot \tanh \left(c_{j}\right) \end{aligned} 值得注意的是該模型爲每個孩子節點都單獨地設置了參數UlU_{l}

模型訓練

分類任務

分類任務定義爲在類別集Y\mathcal{Y}中預測出正確的標籤y^\hat{y},對於每一個節點jj,使用一個softmax分類器來預測節點標籤y^j\hat{y}_j,分類器取每個節點的隱藏狀態hjh_j作爲輸入
p^θ(y{x}j)=softmax(W(s)hj+b(s))y^j=argmaxyp^θ(y{x}j) \begin{aligned} \hat{p}_{\theta}\left(y |\{x\}_{j}\right) &=\operatorname{softmax}\left(W^{(s)} h_{j}+b^{(s)}\right) \\ \hat{y}_{j} &=\arg \max _{y} \hat{p}_{\theta}\left(y |\{x\}_{j}\right) \end{aligned} 損失函數使用negative log-likelihood
J(θ)=1mk=1mlogp^θ(y(k){x}(k))+λ2θ22 J(\theta)=-\frac{1}{m} \sum_{k=1}^{m} \log \hat{p}_{\theta}\left(y^{(k)} |\{x\}^{(k)}\right)+\frac{\lambda}{2}\|\theta\|_{2}^{2} 其中,mm是帶有標籤的節點數量,λ\lambdaL2L2是正則化超參。

語義相關性任務

該任務給定一個句子對(sentence pair),模型需要預測出一個範圍在[1,K][1,K]之間的實數值,這個值越高,表示相似度越高。

論文首先對每一個句子產生一個representation,兩個句子的表示分別用hLh_LhRh_R表示,得到這兩個representation之後,從distance和angle兩個方面考慮,使用神經網絡來得到(hL,hR)(h_L,h_R)相似度:
h×=hLhRh+=hLhRhs=σ(W(×)h×+W(+)h++b(h))p^θ=softmax(W(p)hs+b(p))y^=rTp^θ \begin{aligned} h_{ \times} &=h_{L} \odot h_{R} \\ h_{+} &=\left|h_{L}-h_{R}\right| \\ h_{s} &=\sigma\left(W^{( \times)} h_{ \times}+W^{(+)} h_{+}+b^{(h)}\right) \\ \hat{p}_{\theta} &=\operatorname{softmax}\left(W^{(p)} h_{s}+b^{(p)}\right) \\ \hat{y} &=r^{T} \hat{p}_{\theta} \end{aligned} 其中,rT=[12K]r^{T}=\left[\begin{array}{llll}{1} & {2} & {\ldots} & {K}\end{array}\right]。模型期望根據訓練得到的參數θ\theta得到的結果:y^=rTp^θy\hat{y}=r^{T} \hat{p}_{\theta} \approx y。由此,定義一個目標分佈pp
pi={yy,i=y+1yy+1,i=y0 otherwise  p_{i}=\left\{\begin{array}{ll}{y-\lfloor y\rfloor,} & { i=\lfloor y\rfloor+ 1} \\ {\lfloor y\rfloor- y+1,} & {i=\lfloor y\rfloor} \\ { 0} & {\text { otherwise }}\end{array}\right. 其中,1iK1\le{i}\le{K},損失函數爲ppp^θ\hat{p}_{\theta}之間的KL散度:
J(θ)=1mk=1mKL(p(k)p^θ(k))+λ2θ22 J(\theta)=\frac{1}{m} \sum_{k=1}^{m} \mathrm{KL}\left(p^{(k)} \| \hat{p}_{\theta}^{(k)}\right)+\frac{\lambda}{2}\|\theta\|_{2}^{2}

實驗結果

對於分類任務,實驗數據使用Stanford Sentiment Treebank,分爲五類:very negative, negative, neural, positive和very positive。在該數據集上的測試集上的準確度結果如下:

在這裏插入圖片描述

對於語義相似度度量,模型的任務是預測出兩個句子語義的相似度分數。在SICK的語義相似度子任務上的測試結果如下
在這裏插入圖片描述
對於模型的實現部分目前可以直接使用DGL或者PyG等庫幫助快速實現,接下來會整理如何使用這些庫進行簡單的圖網絡的實驗。圖神經網絡在工業界的挑戰主要是非結構化的圖結構計算需要消耗巨大的計算資源和存儲資源。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章