Pytorch中GNN的基類torch_geometric.nn.conv.MessagePassing

MessagePassing是torch_geometric中GNN模型的基類,實現了下面的消息傳遞公式

 要繼承這個類,需要複寫三個函數:

propagate(edge_index, size=None)

message()

消息傳遞分兩種方式,默認的是source_to_target

update()

其中propagate在執行的過程中會調用message和update

 。。。
#source=>target的消息傳播

 out = self.message(*message_args)
#out爲source頂點,out的shape爲[E,channel],其中E爲邊的條數,channel爲頂點embedding的維度

 out = scatter_(self.aggr, out, edge_index[i], dim, dim_size=size[i])
#將關聯邊的信息加(默認‘add’)到target的頂點上,out的shape爲[V,channel],其中V爲target頂點的個數

 out = self.update(out, *update_args)

 return out

假設頂點V1和頂點v2,v3,v4,.....vn有邊相連,propagate做的事情是將v2,v3,v4,.....vn的信息加(默認‘add’,也可以‘mean’,‘max’)到v1上

GCN的實現,三個函數都是在MessagePassing的基礎上實現的。

唯一關鍵的一步是norm函數,根據GCN的信息傳播的公式,計算鄰接矩陣和對角度矩陣。

                                                            

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels, improved=False, cached=False,
                 bias=True, **kwargs):
        super(GCNConv, self).__init__(aggr='add', **kwargs)
        #略
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=None):
        #略
        #最關鍵的只有這一步,計算鄰接矩陣和對角度矩陣,根據GCN的信息傳播的公式
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """"""
        x = torch.matmul(x, self.weight)

        #略去代碼 主要是設置是否緩存
        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章