圖神經網絡13-圖注意力模型GAT網絡詳解

論文摘要

圖卷積發展至今,早期的進展可以歸納爲譜圖方法和非譜圖方法,這兩者都存在一些挑戰性問題。

  • 譜圖方法:學習濾波器主要基於圖的拉普拉斯特徵,圖的拉普拉斯取決於圖結構本身,因此在特定圖結構上學習到的譜圖模型無法直接應用到不同結構的圖中。
  • 非譜圖方法:對不同大小的鄰域結構,像CNNs那樣設計統一的卷積操作比較困難。

此外,圖結構數據往往存在大量噪聲,換句話說,節點之間的連接關係有時並沒有特別重要,節點的不同鄰居的相對重要性也有差異。

本文提出了圖注意力網絡(GAT),利用masked self-attention layer,通過堆疊網絡層,獲取每個節點的鄰域特徵,爲鄰域中的不同節點分配不同的權重。這樣做的好處是不需要高成本的矩陣運算,也不用事先知道圖結構信息。通過這種方式,GAT可以解決譜圖方法存在的問題,同時也能應用於歸納學習和直推學習問題。

GAT模型結構

假設一個圖有N個節點,節點的F維特徵集合可以表示爲\mathbf{h}=\left\{\vec{h}_{1}, \vec{h}_{2}, \ldots, \vec{h}_{N}\right\}, \vec{h}_{i} \in \mathbb{R}^{F}

注意力層的目的是輸出新的節點特徵集合,
{h}^{\prime}=\left\{\vec{h}_{1}^{\prime}, \vec{h}_{2}^{\prime}, \ldots, \vec{h}_{N}^{\prime}\right\}, \vec{h}_{i}^{\prime} \in {R}^{F^{\prime}}

在這個過程中特徵向量的維度可能會改變,即F \rightarrow F^{\prime} 爲了保留足夠的表達能力,將輸入特徵轉化爲高階特徵,至少需要一個可學習的線性變換。例如,對於節點i,j,對它們的特徵\vec{h}_{i},\vec{h}_{j}應用線性變換W\in\mathbb{R}^{F^{'}\times F},從F維轉化爲F^{\prime} 維新特徵爲\vec{h}_{i}^{\prime},\vec{h}_{j}^{\prime}
e_{i j}=a\left({W} \vec{h}_{i}, {W} \vec{h}_{j}\right)

上式在將輸入特徵運用線性變換轉化爲高階特徵後,使用self-attention爲每個節點分配注意力(權重)。其中a表示一個共享注意力機制:\mathbb{R}^{F^{\prime}} \times \mathbb{R}^{F^{\prime}} \rightarrow \mathbb{R},用於計算注意力係數e_{ij},也就是節點i對節點j的影響力系數(標量)。

上面的注意力計算考慮了圖中任意兩個節點,也就是說,圖中每個節點對目標節點的影響都被考慮在內,這樣就損失了圖結構信息。論文中使用了masked attention,對於目標節點i來說,只計算其鄰域內的節點j\in \mathcal{N}對目標節點的相關度e_{ij}(包括自身的影響)。

爲了更好的在不同節點之間分配權重,我們需要將目標節點與所有鄰居計算出來的相關度進行統一的歸一化處理,這裏用softmax歸一化:

\alpha_{i j}=\operatorname{softmax}_{j}\left(e_{i j}\right)=\frac{\exp \left(e_{i j}\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(e_{i k}\right)}

關於a的選擇,可以用向量的內積來定義一種無參形式的相關度計算\langle {W} \vec{h}_{i}\ , {W} \vec{h}_{j} \rangle,也可以定義成一種帶參的神經網絡層,只要滿足a:R^{d^{(l+1)}} \times R^{d^{(l+1)}} \rightarrow R,即輸出一個標量值表示二者的相關度即可。在論文實驗中,a是一個單層前饋神經網絡,參數爲權重向量\overrightarrow{\mathrm{a}} \in \mathbb{R}^{2 F^{\prime}},使用負半軸斜率爲0.2的LeakyReLU作爲非線性激活函數:

e_{ij} = \text { LeakyReLU }\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \Vert \mathbf{W} \vec{h}_{j}\right]\right)

其中\Vert表示拼接操作。完整的權重係數計算公式爲:

\alpha_{i j}=\frac{\exp \left(\text { LeakyReLU }\left(\overrightarrow{{a}}^{T}\left[{W} \vec{h}_{i} \| {W} \vec{h}_{j}\right]\right)\right)}{\sum_{k \in {N}_{i}} \exp \left(\text { LeakyReLU }\left(\overrightarrow{{a}}^{T}\left[{W} \vec{h}_{i} \| {W} \vec{h}_{k}\right]\right)\right)}

得到歸一化注意係數後,計算其對應特徵的線性組合,通過非線性激活函數後,每個節點的最終輸出特徵向量爲:

\vec{h}_{i}^{\prime}=\sigma\left(\sum_{j \in {N}_{i}} \alpha_{i j} {W} \vec{h}_{j}\right)

多頭注意力機制

另外,本文使用多頭注意力機制(multi-head attention)來穩定self-attention的學習過程,即對上式調用K組相互獨立的注意力機制,然後將輸出結果拼接起來:

\vec{h}_{i}^{\prime}=\Vert_{k=1}^{K} \sigma\left(\sum_{j \in {N}_{i}} \alpha_{i j}^{k} {W}^{k} \vec{h}_{j}\right)

其中\Vert是拼接操作,\alpha_{ij}^{k}是第k組注意力機制計算出的權重係數,W^{(k)}是對應的輸入線性變換矩陣,最終輸出的節點特徵向量\vec{h}_{i}^{\prime}包含了KF^{\prime}個特徵。爲了減少輸出的特徵向量的維度,也可以將拼接操作替換爲平均操作。

\vec{h}_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in {N}_{i}} \alpha_{i j}^{k} {W}^{k} \vec{h}_{j}\right)

下面是K=3的多頭注意力機制示意圖。不同顏色的箭頭表示不同注意力的計算過程,每個鄰居做三次注意力計算,每次attention計算就是一個普通的self-attention,輸出一個\vec{h}_{i}^{\prime},最後將三個不同的\vec{h}_{i}^{\prime}進行拼接或取平均,得到最終的\vec{h}_{i}^{\prime}

不同模型比較

  • GAT計算高效。self-attetion層可以在所有邊上並行計算,輸出特徵可以在所有節點上並行計算;不需要特徵分解或者其他內存耗費大的矩陣操作。單個head的GAT的時間複雜度爲O\left(\mid V\mid F F^{\prime}+\mid E\mid F^{\prime}\right)
  • 與GCN不同的是,GAT爲同一鄰域中的節點分配不同的重要性,提升了模型的性能。
  • 注意力機制以共享的方式應用於圖中的所有邊,因此它不依賴於對全局圖結構的預先訪問,也不依賴於對所有節點(特徵)的預先訪問(這是許多先前技術的限制)。
    • 不必要無向圖。如果邊i\rightarrow j不存在,可以忽略計算e_{ij}
    • 可以用於歸納學習;

評估

數據集

其中前三個引文網絡用於直推學習,第四個蛋白質交互網絡PPI用於歸納學習。

實驗設置

  • 直推學習

    • 兩層GAT模型,第一層多頭注意力K=8,輸出特徵維度F^{\prime}=8(共64個特徵),激活函數爲指數線性單元(ELU);
    • 第二層單頭注意力,計算C個特徵(C爲分類數),接softmax激活函數;
    • 爲了處理小的訓練集,模型中大量採用正則化方法,具體爲L2正則化;
    • dropout;
  • 歸納學習:

    • 三層GAT模型,前兩層多頭注意力K=4,輸出特徵維度F^{\prime}=256(共1024個特徵),激活函數爲指數非線性單元(ELU);
    • 最後一層用於多標籤分類,K=6,每個頭計算121個特徵,後接logistic sigmoid激活函數;
    • 不使用正則化和dropout;
    • 使用了跨越中間注意力層的跳躍連接。
    • batch_size = 2 graph

實驗結果

  • 不同數據集的分類準確率效果對比(Transductive)


  • 數據集PPI上的F1效果(歸納學習)
  • 可視化

核心代碼

GAT層代碼:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0] # number of nodes

        # Below, two matrices are created that contain embeddings in their rows in different orders.
        # (e stands for embedding)
        # These are the rows of the first matrix (Wh_repeated_in_chunks): 
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        # 
        # These are the rows of the second matrix (Wh_repeated_alternating): 
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN 
        # '----------------------------------------------------' -> N times
        # 
        
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)

        # The all_combination_matrix, created below, will look like this (|| denotes concatenation):
        # e1 || e1
        # e1 || e2
        # e1 || e3
        # ...
        # e1 || eN
        # e2 || e1
        # e2 || e2
        # e2 || e3
        # ...
        # e2 || eN
        # ...
        # eN || e1
        # eN || e2
        # eN || e3
        # ...
        # eN || eN

        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
        # all_combinations_matrix.shape == (N * N, 2 * out_features)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

GAT模型


import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphAttentionLayer, SpGraphAttentionLayer


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)

參考文章

圖神經網絡:圖注意力網絡(GAT) https://jjzhou012.github.io/blog/2020/01/28/Graph-Attention-Networks.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章