DGL框架之 message function 和 reduce function 相關介紹

DGL

兩個API

在這裏插入圖片描述

  • message function(消息函數)

消息函數通過邊獲取變量:

(1)用 e.src.data 獲得這條邊出發節點的特徵信息

(2)用 e.dst.data 獲得這條邊目標節點的特徵信息

(3)用 e.data 獲得這條邊的特徵信息

消息函數可以獲得出發節點和目標節點的特徵信息,描述了需要發給目標節點做下一步計算的信息。

在這裏插入圖片描述
如上圖,消息函數把節點1和節點2的信息都發送給節點3,可以發送的信息包括 v1、v2 和 v3 以及每條邊上的消息。

  • reduce function(累和函數)

目標節點在獲得其他節點以及邊的特徵信息之後,通過累和函數計算出一個新的表示。

在這裏插入圖片描述
如上圖,累和函數獲得了消息函數傳遞過來的信息 M13、M23 同時還有自身的節點信息。

  • 圖神經網絡案例
import torch
import torch.nn as nn

class GCNMessage(nn.Module):
    """消息函數"""
    def forward(self, edges):
        """
        :param: edges, a batch of edges
        :return: This computes a (batch of) message called 'msg' 
        	using the source node's feature 'h'
        """
        # 源點的批量特徵向量
        return {"h": edges.src["h"]}
    
class GCNReduce(nn.Module):
    """累和函數"""
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = nn.ReLU()
        
    def forward(self, nodes):
        """
        :param: nodes, a batch of nodes
        :return: This computes the new 'h' features 
        	by summing received 'msg' in each node's mailbox
        """
        # 批量消息張量, nodes.mailbox["h"]
        accum = torch.sum(nodes.mailbox["h"], dim=1)
        h = self.linear(accum)
        h = self.activation(h)
        return {"h": h}
    
class GCN(nn.Module):
    """GCN Layer"""
	def __init__(self, in_feats, out_feats):
        self.msg_func = GCNMessage()
        self.reduce_func = GCNReduce(in_feats, out_feats)
        
    def forward(self, g, inputs):
        """
        :param: g, the graph
        :param: inputs, the input node features
        """
        # first set the node features
        g.ndata["h"] = inputs
        # 全局更新
        g.update_all(self.msg_func, self.reduce_func)
        """
        Or
        g.send(g.edges(), gcn_message)
        g.recv(g.nodes(), gcn_reduce)
        And there the gcn_message and gcn_reduce are functions.
        """
        # Get the 'h' features and remove the node/edge states from the graph
        return g.ndata.pop("h")

這裏的累和函數:

hinew=f(ΣjiNhj)h^{new}_i = f ( \Sigma_{j \neq i}^N h_j )

Builtin message passing functions

DGL 的消息傳遞主要使用兩個 API:

  • send(edges, message_func) 用於計算沿着給定邊的消息
  • recv(nodes, reduce_func) 用於收集進入節點的消息,執行聚集等操作

使用 uve分別表示 source nodes,destination nodes 和 edges

消息傳遞使用案例:

import dgl
import dgl.function as fn
import torch

# create a DGL Graph
g = ...
# each node has feature size 10
g.ndata['h'] = torch.randn((g.number_of_nodes(), 10))
# each edge has feature size 1
g.edata['w'] = torch.randn((g.number_of_edges(), 1))

# collect features from source nodes and aggregate them in destination nodes
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))

# multiply source node features with edge weights and aggregate them in destination nodes
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))

# compute edge embedding by multiplying source and destination node embeddings
g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))

對於一元消息函數(e.g. copy_u)需要一個輸入的特徵名和一個輸出的消息名;對於二元消息函數(e.g. u_mul_e)需要兩個輸入特徵名和一個輸出消息名。對於 fn.u_mul_e('h', 'w', 'm') 是按如下函數定義:

def udf_u_mul_e(edges):
	return {'m': edges.src['h'] * edges.data['w']}

對於 reduce function,需要給出一個輸入消息名和一個輸出節點特徵名,例如,fn.max('m', 'h_max') 是按如下定義:

def udf_max(nodes):
    return {'h_max': torch.max(nodes.mailbox['m'], 1)[0]}

MultiGraphs

創建圖時需要設置 multigraph=True

g_multi = dgl.DGLGraph(multigraph=True)
g_multi.add_nodes(10)
g_multi.ndata['x'] = torch.randn(10, 2)

# edges, [(1, 0), (2, 0), (3, 0), ..., (9, 0), (1, 0)]
# two edges on 1->0
g_multi.add_edges(list(range(,, 10)), 0)
g_multi.add_edges(1, 0)

g_multi.edata['w'] = torch.randn(10, 2)
# set the first 1->0 edge's data
g_multi.edges[1].data['w'] = torch.zeros(1, 2)

MultiGraph 中的邊沒法通過節點 uuvv 唯一確定,需要使用 edge_id 獲取邊的 id

# tensor([0, 9])
eid_10 = g_multi.edge_id(1, 0)
g_multi.edges[eid_10].data['w'] = torch.ones(len(eid_10), 2)

Reference

DGL at a Glance

DGL Basics

Builtin message passing functions

DGL 作者答疑!關於 DGL 你想知道的都在這裏

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章