圖注意力網絡(GAT,GraphAttentionNetwork)

GAT(GRAPH ATTENTION NETWORKS)是一種使用了self attention機制圖神經網絡,該網絡使用類似transformer裏面self attention的方式計算圖裏面某個節點相對於每個鄰接節點的注意力,將節點本身的特徵和注意力特徵concate起來作爲該節點的特徵,在此基礎上進行節點的分類等任務。

下面是transformer self attention原理圖:



GAT使用了類似的流程計算節點的self attention,首先計算當前節點和每個鄰接節點的注意力score,然後使用該score乘以每個節點的特徵,累加起來並經過一個非線性映射,作爲當前節點的特徵。



Attention score公式表示如下:

這裏使用W矩陣將原始的特徵映射到一個新的空間,a代表self attention的計算,如前面圖2所示,這樣計算出兩個鄰接節點的attention score,也就是Eij,然後對所有鄰接節點的score進行softmax處理,得到歸一化的attention score。
代碼可以參考這個實現:https://github.com/gordicaleksa/pytorch-GAT
核心代碼:

    def forward(self, data):
        in_nodes_features, connectivity_mask = data  
        num_of_nodes = in_nodes_features.shape[0]
        in_nodes_features = self.dropout(in_nodes_features)
        # V
        nodes_features_proj = self.linear_proj(in_nodes_features).view(-1, self.num_of_heads, self.num_out_features)

        nodes_features_proj = self.dropout(nodes_features_proj)  
        # Q、K
        scores_source = torch.sum((nodes_features_proj * self.scoring_fn_source), dim=-1, keepdim=True)
        scores_target = torch.sum((nodes_features_proj * self.scoring_fn_target), dim=-1, keepdim=True)

        scores_source = scores_source.transpose(0, 1)
        scores_target = scores_target.permute(1, 2, 0)
        # Q * K
        all_scores = self.leakyReLU(scores_source + scores_target)
        all_attention_coefficients = self.softmax(all_scores + connectivity_mask)
        # Q * K * V
        out_nodes_features = torch.bmm(all_attention_coefficients, nodes_features_proj.transpose(0, 1))

        out_nodes_features = out_nodes_features.permute(1, 0, 2)
        # in_nodes_features + out_nodes_features(attention)
        out_nodes_features = self.skip_concat_bias(all_attention_coefficients, in_nodes_features, out_nodes_features)
        return (out_nodes_features, connectivity_mask)

該GAT的實現也包含在了PYG庫中,這個庫涵蓋了各種常見的圖神經網絡方面的論文算法實現。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章