概括
使用GNN訓練除能夠用於多個圖算法的算法執行器, 證明了GNN在圖結構的輸入上的強大表示能力.
將傳統圖算法(BFS, Bell-Ford, Prim)中算法執行的每一步的決策作爲標籤,上一步決策後的圖信息作爲輸入, 預測算法在當前時間步的決策, 網絡最終能夠學習到傳統算法的選擇策略.
文章的動機: 目標是增強GNN的更新規則中的算法可解釋性, 增強這種歸納性的認知。這種認知對於例如發現新穎的算法或改進現有算法是個很重要的前提。文章還通過實驗證明了在學習兩種不同算法(BFS和Bellman-Ford)之間的轉換,這直接指出了一個學到的算法執行器增強另一種算法的預測的準確性的潛力。
問題設定
輸入:
一個T長的graph-structured的圖的序列: 序列中的圖G是固定的,點和邊的信息是變化的
點 x i ⃗ ( t ) ∈ R N x , i ∈ V \vec{x_i}^{(t)} \in R^{N_x}, \quad i \in V x i ( t ) ∈ R N x , i ∈ V
邊 e i j ⃗ ( t ) ∈ R N e \vec{e_{ij}}^{(t)} \in R^{N_e} e i j ( t ) ∈ R N e
待學習的算法A
輸出:
y i ⃗ ( t ) ∈ R N y , i ∈ V \vec{y_i}^{(t)} \in R^{N_y}, \quad i \in V y i ( t ) ∈ R N y , i ∈ V
網絡的組成:
Encoder: f A f_A f A 與A相關, 輸出與節點相關, z i ⃗ ( t ) = f A ( x i ( t ) , h i ⃗ ( t − 1 ) ) \vec{z_i}^{(t)}=f_A(x_i^{(t)}, \vec{h_i}^{(t-1)}) z i ( t ) = f A ( x i ( t ) , h i ( t − 1 ) ) , 其中h i ( 0 ) = 0 ⃗ {h_i}^{(0)} = \vec{0} h i ( 0 ) = 0
Processor: P P P 與A無關, H ( t ) = P ( Z ( t ) , E ( t ) ) H^{(t)}=P(Z^{(t)}, E^{(t)}) H ( t ) = P ( Z ( t ) , E ( t ) )
Decoder: g A g_A g A 與A相關, 輸出與節點相關, y i ⃗ ( t ) = g A ( z i ( t ) , h i ⃗ ( t ) ) \vec{y_i}^{(t)}=g_A(z_i^{(t)}, \vec{h_i}^{(t)}) y i ( t ) = g A ( z i ( t ) , h i ( t ) )
Temination: T A T_A T A 與A相關, τ ( t ) = σ ( T A ( H ( t ) , H ( t ) ‾ ) \tau^{(t)}=\sigma(T_A(H^{(t)}, \overline{H^{(t)}}) τ ( t ) = σ ( T A ( H ( t ) , H ( t ) ) , 得到的是決定算法是否要進行下一個時間步的概率, 當τ > 0.5 \tau > 0.5 τ > 0 . 5 時, 算法從Encoder重複.
其中H ( t ) ‾ = 1 ∣ V ∣ ∑ i ∈ V h i ⃗ ( t ) \overline{H^{(t)}} = \frac{1}{|V|} \sum_{i\in V}\vec{h_i}^{(t)} H ( t ) = ∣ V ∣ 1 ∑ i ∈ V h i ( t )
Tips:
x i ⃗ t + 1 \vec{x_i}^{t+1} x i t + 1 會部分或完全引入y i ⃗ t \vec{y_i}^{t} y i t 的信息,即從上一個時間步的決策作爲下一步的輸入信息.
輸出中可以很容易的輸出和邊相關的結果以及和圖相關的結果, 文章中並沒有用到,所以輸出只輸出和節點相關的內容.
h i h_i h i 的計算可以通過GNN層獲得, 使用到用於比較的GNN有GAT, MPNNS及變種
實驗設定
圖類型:
7種
每個點, 新加self-edge: 利於信息傳遞
對每個邊分配隨機分配權重, 符合[0.2, 1]的平均分佈, 對於所有時間步t, 這個參數作爲邊上的唯一信息.
這種邊權採樣保證了結果恢復的唯一性, 簡化了對比分析
爲什麼認爲學習算法的執行具有擴展性: 人類專家從小圖觀察規律, 設計算法用於大圖, 因此認爲算法也能夠學到類似的規律
Parallel algo:
BFS 每個點維護bool, 表示是否可達
BF 每個點維護Real, 表示與s的最短距離
文章斷言稱這類算法的需要在鄰域上做離散決策的特點,很適合max-aggregator MPNN來學習執行.
BFS:
x i ( 1 ) = { 1 , i = s 0 , i ≠ s x i ( t + 1 ) = { 1 , x i t = 1 1 , ∃ j . ( j , i ) ∧ x j ( t ) = 1 0 , o t h e r w i s e x i ( t + 1 ) = y i ( t )
x_{i}^{(1)}=
\begin{cases}
\begin{aligned}
1 & , & i=s \\
0 & , & i \neq s
\end{aligned}
\end{cases}
\qquad
x_{i}^{(t+1)}=
\begin{cases}
\begin{aligned}
1 & , & x_i^{t} = 1 \\
1 & , & \exist j.(j,i) \land x_j^{(t)}=1 \\
0 & , & otherwise
\end{aligned}
\end{cases}
\\ \quad
\\
x_i^{(t+1)} = y_i^{(t)}
x i ( 1 ) = { 1 0 , , i = s i = s x i ( t + 1 ) = ⎩ ⎪ ⎪ ⎨ ⎪ ⎪ ⎧ 1 1 0 , , , x i t = 1 ∃ j . ( j , i ) ∧ x j ( t ) = 1 o t h e r w i s e x i ( t + 1 ) = y i ( t )
BF:
x i ( 1 ) = { 1 , i = s + ∞ , i ≠ s x i ( t + 1 ) = m i n ( x i ( t ) , m i n j , i ) x j ( t ) + e j i ( t ) ⃗ ) y i ( t ) = P i ( t ) ∣ ∣ x i ( t + 1 ) (|| means concat)
x_{i}^{(1)}=
\begin{cases}
\begin{aligned}
1 & , & i=s \\
+\infty & , & i \neq s
\end{aligned}
\end{cases}
\qquad
x_{i}^{(t+1)}=min(\vec{x_i^{(t)}, min_{j,i)} x_j^{(t)} + e_{ji}^{(t)}})
\\ \quad
\\
y_i^{(t)}=P_i^{(t)} || x_i^{(t+1)} \text{(|| means concat)}
x i ( 1 ) = { 1 + ∞ , , i = s i = s x i ( t + 1 ) = m i n ( x i ( t ) , m i n j , i ) x j ( t ) + e j i ( t ) ) y i ( t ) = P i ( t ) ∣ ∣ x i ( t + 1 ) (|| means concat)
Sequential algo:
seq的一類算法, 在每個時間步只處理一個節點, 在構造算法中常見
文章希望能夠證明, 圖神經網絡也能很好的學到序列模型的算法
Prim:
x i ( 1 ) = { 1 , i = s 0 , i ≠ s x i ( t + 1 ) = { 1 , x i t = 1 1 , i = a r g m i n j s . t . x j ( t ) = 0 m i n k . s . t . x k ( t ) = 1 e j k ( t ) , ∗ 0 , o t h e r w i s e p i ( t ) = { i , i = s P i ( t − 1 ) , i ≠ s ∧ x i ( t ) = 1 a r g m i n x j ( t ) = 1 e i j ( t ) , x i ( t ) = 0 ∧ x i ( t + 1 ) = 1 , ∗ y i ( t ) = P i ( t ) ∣ ∣ x i ( t + 1 ) (|| means concat)
x_{i}^{(1)}=
\begin{cases}
\begin{aligned}
1 & , & i=s \\
0 & , & i \neq s
\end{aligned}
\end{cases}
\qquad
x_{i}^{(t+1)}=
\begin{cases}
\begin{aligned}
1 & , & x_i^{t} = 1 \\
1 & , & i=argmin_{j s.t. x_j^{(t)}=0} min_{k.s.t. x_k^{(t)}=1} e_{jk}^{(t)},*\\
0 & , & otherwise
\end{aligned}
\end{cases}
\\ \quad
\\
p_i^{(t)}=
\begin{cases}
\begin{aligned}
i & , & i=s\\
P_i^{(t-1)} &, & i\neq s \land x_i^{(t)}=1\\
argmin_{x_j^{(t)}=1} e_{ij}^{(t)} &, & x_i^{(t)}=0 \land x_i^{(t+1)}=1, *
\end{aligned}
\end{cases}
\qquad
y_i^{(t)}=P_i^{(t)} || x_i^{(t+1)} \text{(|| means concat)}
x i ( 1 ) = { 1 0 , , i = s i = s x i ( t + 1 ) = ⎩ ⎪ ⎪ ⎨ ⎪ ⎪ ⎧ 1 1 0 , , , x i t = 1 i = a r g m i n j s . t . x j ( t ) = 0 m i n k . s . t . x k ( t ) = 1 e j k ( t ) , ∗ o t h e r w i s e p i ( t ) = ⎩ ⎪ ⎪ ⎨ ⎪ ⎪ ⎧ i P i ( t − 1 ) a r g m i n x j ( t ) = 1 e i j ( t ) , , , i = s i = s ∧ x i ( t ) = 1 x i ( t ) = 0 ∧ x i ( t + 1 ) = 1 , ∗ y i ( t ) = P i ( t ) ∣ ∣ x i ( t + 1 ) (|| means concat)
網絡結構
在processor的模型選擇上, 除了max-aggregator的MPNN, 文章還設置了GAT, mean-aggregator MPNN, sum-aggregator MPNN. 爲了說明圖神經網絡是必須的, 文章還設置了非圖神經網絡的結構, 由於前面已經有文章說明, MLP不適合變節點樹的情況, 最終引入了一個LSTM的結構作爲對照.
模型使用的是監督學習
bin cross-entropy: reachibility
mean squared error: distance
categorical cross-entropy: predecessor node
bin cross-entropy: termination
Tips
如果在|V|步後模型的termination沒有終止, 默認算法終止
對於seq類的問題,categorical cross-entropy預測的是下一個節點, 爲此需要將未加入MST的節點添加mask
結果分析討論
訓練:使用模型在20個節點的圖上學習
測試方法: 應用到20, 50, 100節點規模測試過程準確性和最終準確性.
結論:
額外指標 : 遷移學習的對比
動機: 圖網絡學習到的是什麼信息?
通過在一類圖上訓練, 在另一類圖上測試, 得出MPNN-max能夠在學習到輸入圖中的結構相似性在不同類的圖中獲得較好的bias.