DGL
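Two APIs
- message function
The message function reads data through the edges:
(1) edges.src gives the features of each edge's source node
(2) edges.dst gives the features of each edge's destination node
(3) edges.data gives the features of the edge itself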
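The message function can read the features of both the source node and the destination node. Its output describes the information to be sent to the destination node for the next computation step.
As shown in the figure above, the message function sends the information of nodes 1 and 2 to node 3. What can be sent includes v1, v2 and v3 as well as the data on each edge.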
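To make the idea concrete, here is a dependency-free toy model (not DGL's API) of a message function on the graph from the figure: edges 1->3 and 2->3, scalar node features standing in for v1, v2, v3, and one weight per edge. The names `EdgeView`, `message_fn` and `compute_messages`, and all the numbers, are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

# Hypothetical toy graph mirroring the figure: nodes 1 and 2 both point to node 3.
node_feat: Dict[int, float] = {1: 1.0, 2: 2.0, 3: 3.0}                 # v1, v2, v3
edge_feat: Dict[Tuple[int, int], float] = {(1, 3): 0.5, (2, 3): 0.25}  # one weight per edge


@dataclass
class EdgeView:
    """Toy stand-in for one edge: the data a message function may read."""
    src: float   # feature of the source node
    dst: float   # feature of the destination node
    data: float  # feature of the edge itself


def message_fn(edge: EdgeView) -> Dict[str, float]:
    # Any mix of src, dst and data is allowed; this one sends the
    # source feature scaled by the edge weight.
    return {"m": edge.src * edge.data}


def compute_messages() -> Dict[Tuple[int, int], Dict[str, float]]:
    """Run message_fn on every edge and key each result by (src, dst)."""
    return {
        (u, v): message_fn(EdgeView(node_feat[u], node_feat[v], w))
        for (u, v), w in edge_feat.items()
    }


messages = compute_messages()
print(messages)  # {(1, 3): {'m': 0.5}, (2, 3): {'m': 0.5}}
```

In DGL the same function receives a whole batch of edges at once, so `edges.src["h"]` is a tensor with one row per edge rather than a single number.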
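- reduce function
After the destination node has received the features of the other nodes and of the edges, the reduce function computes a new representation for it.
As shown in the figure above, the reduce function receives the messages M13 and M23 passed along by the message function, together with the node's own features.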
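A minimal sketch of the reduce step, again as a plain-Python toy rather than DGL code: messages are first grouped by destination into a per-node "mailbox", then each node combines its own feature with what it received. The values for M13, M23 and the node features, and the helpers `build_mailboxes` and `reduce_fn`, are hypothetical; "own feature plus sum" is just one possible reduction.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical messages on edges (src, dst): M13 on 1->3 and M23 on 2->3.
messages: Dict[Tuple[int, int], float] = {(1, 3): 0.5, (2, 3): 0.5}
node_feat: Dict[int, float] = {1: 1.0, 2: 2.0, 3: 3.0}


def build_mailboxes(msgs: Dict[Tuple[int, int], float]) -> Dict[int, List[float]]:
    """Group messages by destination node: each node's 'mailbox'."""
    mailbox: Dict[int, List[float]] = defaultdict(list)
    for (_, dst), m in msgs.items():
        mailbox[dst].append(m)
    return mailbox


def reduce_fn(own: float, received: List[float]) -> float:
    # New representation = own feature + sum of incoming messages.
    return own + sum(received)


mailbox = build_mailboxes(messages)
new_feat = {v: reduce_fn(node_feat[v], box) for v, box in mailbox.items()}
print(new_feat)  # {3: 4.0}
```

Only nodes that actually received messages (node 3 here) appear in the result of this toy.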
- Graph neural network example: a GCN layer
import torch
import torch.nn as nn


class GCNMessage(nn.Module):
    """Message function"""
    def forward(self, edges):
        """
        :param: edges, a batch of edges
        :return: a (batch of) message called 'msg'
                 computed from the source node's feature 'h'
        """
        # batched feature vectors of the source nodes
        return {"msg": edges.src["h"]}


class GCNReduce(nn.Module):
    """Reduce function"""
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = nn.ReLU()

    def forward(self, nodes):
        """
        :param: nodes, a batch of nodes
        :return: the new 'h' features, computed
                 by summing the received 'msg' in each node's mailbox
        """
        # batched message tensor, shape (batch_nodes, in_degree, in_feats)
        accum = torch.sum(nodes.mailbox["msg"], dim=1)
        h = self.linear(accum)
        h = self.activation(h)
        return {"h": h}


class GCN(nn.Module):
    """GCN Layer"""
    def __init__(self, in_feats, out_feats):
        super().__init__()  # required before assigning submodules
        self.msg_func = GCNMessage()
        self.reduce_func = GCNReduce(in_feats, out_feats)

    def forward(self, g, inputs):
        """
        :param: g, the graph
        :param: inputs, the input node features
        """
        # first set the node features
        g.ndata["h"] = inputs
        # global update: send messages on all edges, reduce on all nodes
        g.update_all(self.msg_func, self.reduce_func)
        # Older DGL versions also allowed the two steps separately:
        #   g.send(g.edges(), gcn_message)
        #   g.recv(g.nodes(), gcn_reduce)
        # where gcn_message and gcn_reduce are plain functions.
        # Get the 'h' features and remove them from the graph
        return g.ndata.pop("h")
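The arithmetic of this layer can be followed by hand in a plain-Python version: messages copy the source feature, the reduce step sums them per destination, applies `W x + b` and then ReLU. This sketch replaces torch and DGL with lists; the helpers `matvec` and `gcn_layer`, the graph, the weights and the bias are all made up for illustration.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def matvec(W: List[Vector], x: Vector) -> Vector:
    """y = W x for an out_feats x in_feats matrix W."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def gcn_layer(edges: List[Tuple[int, int]], h: Dict[int, Vector],
              W: List[Vector], b: Vector) -> Dict[int, Vector]:
    """Toy GCN layer: sum source features per destination, then linear + ReLU.

    Nodes with no incoming edge are left out of the result.
    """
    in_feats = len(next(iter(h.values())))
    accum: Dict[int, Vector] = {}
    for src, dst in edges:
        acc = accum.setdefault(dst, [0.0] * in_feats)
        for i, x in enumerate(h[src]):  # message = source feature 'h'
            acc[i] += x                 # reduce = element-wise sum
    out: Dict[int, Vector] = {}
    for v, s in accum.items():
        y = [yi + bi for yi, bi in zip(matvec(W, s), b)]  # linear layer
        out[v] = [max(0.0, yi) for yi in y]               # ReLU
    return out


# Hypothetical 3-node graph: 0->2, 1->2, 2->0, 2-dim inputs and outputs.
h = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [1.0, 1.0]}
W = [[1.0, 1.0],
     [1.0, -1.0]]
b = [0.0, 0.0]
result = gcn_layer([(0, 2), (1, 2), (2, 0)], h, W, b)
print(result)  # {2: [3.0, 0.0], 0: [2.0, 0.0]}
```

Node 2 sums [1, 0] and [0, 2] into [1, 2]; W maps that to [3, -1], and ReLU clips the negative entry to 0.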
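The reduce function here computes, for each destination node $v$, $h_v' = \mathrm{ReLU}\big(W \sum_{u \in \mathcal{N}(v)} h_u + b\big)$, where $\mathcal{N}(v)$ is the set of source nodes of $v$'s incoming edges and $W$, $b$ are the weight and bias of `self.linear`.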
Builtin message passing functions
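DGL's message passing is built mainly on two APIs:
send(edges, message_func) computes the messages along the given edges;
recv(nodes, reduce_func) collects the messages arriving at the given nodes and aggregates them.
update_all, used in the examples here, performs both steps over the whole graph.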
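The split between the two steps can be sketched with two hypothetical pure-Python functions that only imitate the roles of `send` and `recv`; they are not DGL's signatures, which operate on edge and node batches of tensors. Scalar features and the 3-node graph are made up.

```python
from typing import Callable, Dict, List, Tuple

Edge = Tuple[int, int]


def send(edges: List[Edge], feat: Dict[int, float],
         message_func: Callable[[float], float]) -> Dict[Edge, float]:
    """Toy 'send': compute one message per given edge from its source feature."""
    return {(u, v): message_func(feat[u]) for u, v in edges}


def recv(nodes: List[int], messages: Dict[Edge, float],
         reduce_func: Callable[[List[float]], float]) -> Dict[int, float]:
    """Toy 'recv': for each given node, aggregate the messages addressed to it."""
    out: Dict[int, float] = {}
    for n in nodes:
        inbox = [m for (_, v), m in messages.items() if v == n]
        if inbox:  # nodes that received nothing are skipped in this toy
            out[n] = reduce_func(inbox)
    return out


feat = {0: 1.0, 1: 2.0, 2: 3.0}
edges = [(0, 2), (1, 2), (2, 0)]
msgs = send(edges, feat, lambda x: 2 * x)  # step 1: messages along edges
result = recv([0, 1, 2], msgs, sum)        # step 2: aggregate into nodes
print(result)  # {0: 6.0, 2: 6.0}
```

Running `send` on all edges followed by `recv` on all nodes is the pattern that `update_all` bundles into a single call.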
The built-in functions use u, v and e to denote source nodes, destination nodes and edges, respectively.
Message passing example:
import dgl
import dgl.function as fn
import torch
# create a DGL Graph
g = ...
# each node has feature size 10
g.ndata['h'] = torch.randn((g.number_of_nodes(), 10))
# each edge has feature size 1
g.edata['w'] = torch.randn((g.number_of_edges(), 1))
# collect features from source nodes and aggregate them in destination nodes
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
# multiply source node features with edge weights and aggregate them in destination nodes
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
# compute edge embedding by multiplying source and destination node embeddings
g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))
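As a plain-Python reading of what the three calls above compute, the sketch below uses scalar features on a made-up 3-node graph instead of torch tensors. The `aggregate` helper is hypothetical; each comment names the DGL call it imitates.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical graph: edge id -> (src, dst); scalar features instead of tensors.
edges: List[Tuple[int, int]] = [(0, 2), (1, 2), (2, 0)]
h = [1.0, 2.0, 3.0]   # node feature 'h'
w = [0.5, -1.0, 2.0]  # edge feature 'w', one per edge id


def aggregate(msgs: List[float], dsts: List[int],
              reducer: Callable[[List[float]], float]) -> Dict[int, float]:
    """Group per-edge messages by destination node and reduce each group."""
    boxes: Dict[int, List[float]] = {}
    for m, d in zip(msgs, dsts):
        boxes.setdefault(d, []).append(m)
    return {d: reducer(ms) for d, ms in boxes.items()}


dsts = [v for _, v in edges]

# like fn.copy_u('h', 'm') + fn.sum('m', 'h_sum')
h_sum = aggregate([h[u] for u, _ in edges], dsts, sum)

# like fn.u_mul_e('h', 'w', 'm') + fn.max('m', 'h_max')
h_max = aggregate([h[u] * w[e] for e, (u, _) in enumerate(edges)], dsts, max)

# like fn.u_mul_v('h', 'h', 'w_new'): one value per edge, no aggregation
w_new = [h[u] * h[v] for u, v in edges]

print(h_sum, h_max, w_new)
```

The toy only produces entries for nodes that receive at least one message, so node 1 is absent from `h_sum` and `h_max`.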
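A unary message function (e.g. copy_u) takes one input feature name and one output message name; a binary message function (e.g. u_mul_e) takes two input feature names and one output message name. For example, fn.u_mul_e('h', 'w', 'm') is equivalent to the following user-defined function: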
def udf_u_mul_e(edges):
    return {'m': edges.src['h'] * edges.data['w']}
A reduce function takes one input message name and one output node feature name. For example, fn.max('m', 'h_max') is equivalent to:
def udf_max(nodes):
    # torch.max along dim 1 returns (values, indices); keep the values
    return {'h_max': torch.max(nodes.mailbox['m'], 1)[0]}
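In `udf_max`, the mailbox holds the messages of a batch of nodes with shape (number of nodes, in-degree, feature size), so `torch.max(..., 1)` takes an element-wise maximum over the message axis, and `[0]` selects the values from the (values, indices) pair it returns. The stdlib sketch below reproduces only the values part with nested lists; `max_over_messages` and the mailbox contents are hypothetical.

```python
from typing import List

# mailbox[node][message][feature]: shape (num_nodes, in_degree, feat_size)
Mailbox = List[List[List[float]]]


def max_over_messages(mailbox: Mailbox) -> List[List[float]]:
    """Element-wise max over the message axis (dim 1) for every node."""
    return [[max(col) for col in zip(*msgs)] for msgs in mailbox]


# Hypothetical batch: 2 nodes, each with 2 incoming messages of size 3.
mailbox = [
    [[1.0, 5.0, -1.0], [4.0, 0.0, 2.0]],
    [[0.0, 0.0, 0.0], [-3.0, 1.0, 7.0]],
]
print(max_over_messages(mailbox))  # [[4.0, 5.0, 2.0], [0.0, 1.0, 7.0]]
```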
MultiGraphs
When creating the graph, set multigraph=True:
g_multi = dgl.DGLGraph(multigraph=True)
g_multi.add_nodes(10)
g_multi.ndata['x'] = torch.randn(10, 2)
# edges, [(1, 0), (2, 0), (3, 0), ..., (9, 0), (1, 0)]
# two edges on 1->0
g_multi.add_edges(list(range(1, 10)), 0)
g_multi.add_edges(1, 0)
g_multi.edata['w'] = torch.randn(10, 2)
# set the data of the first 1->0 edge (edge id 0)
g_multi.edges[0].data['w'] = torch.zeros(1, 2)
An edge in a multigraph cannot be uniquely identified by its endpoints u and v, so use edge_id to look up the edge ids:
# tensor([0, 9])
eid_10 = g_multi.edge_id(1, 0)
g_multi.edges[eid_10].data['w'] = torch.ones(len(eid_10), 2)
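The result tensor([0, 9]) follows from how the edge list was built: edge id i is simply the i-th added edge, and the pair (1, 0) was added twice, first as id 0 and last as id 9. The hypothetical helper `edge_ids` below (plain Python, not DGL's function) recovers all ids for a node pair from such a list.

```python
from typing import List, Tuple

# Edge list in insertion order, as built above: ids 0..8 are (1..9, 0), id 9 is (1, 0).
edges: List[Tuple[int, int]] = [(u, 0) for u in range(1, 10)] + [(1, 0)]


def edge_ids(edges: List[Tuple[int, int]], u: int, v: int) -> List[int]:
    """Return every edge id whose endpoints are (u, v); a multigraph may have several."""
    return [eid for eid, (s, d) in enumerate(edges) if (s, d) == (u, v)]


print(edge_ids(edges, 1, 0))  # [0, 9]
print(edge_ids(edges, 2, 0))  # [1]
```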