【PyG入門學習】三:信息傳遞機制

1.理論基礎

將普通的卷積過程推廣到非規則數據領域一般是通過鄰域聚合或者信息傳遞機制。xi(k1)RFx^{(k-1)}_i∈R^F表示在第k-1層節點i的節點特徵,ej,iRDe_{j,i}∈R^D表示從節點j到節點i的邊的特徵(可選參數),那麼圖神經網絡中的信息傳遞機制就可以表示爲:
在這裏插入圖片描述
其中□ 表示一種可微的、置換不變的函數(也就是後面的聚合模式),比如求和、取均值或者最大值,γ\gammaϕ\phi均爲可微的函數,比如MLP多層感知機。上述公式相當於就是把一個節點的鄰域節點的特徵聚合到當前節點上面,最外層的γ\gamma函數就類似於我們常見的非線性激活函數,聚合的信息分爲兩部分,第一部分是上一層中該節點自身的特徵信息,第二部分是上一層中,該節點和鄰域節點邊上的傳遞信息。

2.“信息傳遞”基類

Pytorch-Geometric中提供了一個基類torch_geometric.nn.MessagePassing,它自身已經實現了信息傳遞機制來更有效的創建信息傳遞機制的圖神經網絡,只要將其作爲一個基類繼承創建自己的類即可。使用的時候只需要定義函數ϕ\phi
比如message(),和函數γ\gamma比如update();同時需要指定聚合方式比如aggr='add'aggr='mean'或者aggr='max'。在這個基類中比較重要的幾個地方如下:
(1)torch_geometric.nn.MessagePassing(aggr="add", flow="source_to_target")
定義三種聚合模式中的一種以及信息傳遞的方向,默認是從源節點到目標節點,比如一個有向邊1->2,源節點是1,目標節點是2。
(2)torch_geometric.nn.MessagePassing.propagate(edge_index, size=None, dim=0, **kwargs)
調用該函數會進行信息的傳播計算過程,參數爲邊的數據以及其他在構建信息傳遞過程和更新節點嵌入向量的數據參數(這裏的額外的數據參數並不會在這該函數用到,而是傳遞到之後的函數中)。值得注意的是該方法不僅限於shape=[N, N]的鄰接矩陣,也可以用於一些稀疏化的矩陣,對於稀疏化矩陣如果創建完整的鄰接矩陣對於空間浪費比較大,所以只會存儲其中非0元素(存儲該元素的行座標和列座標),比如二分圖;對於矩陣格式shape=[N, M]需要傳遞參數size=(N, M),如果該參數爲None,就會默認爲是規則的鄰接矩陣。對於二分圖而言,含有兩個獨立的節點索引,所以傳遞參數的方式可以類似於x=(x_N, x_M)的形式。
(3)torch_geometric.nn.MessagePassing.message()
對到達節點i的信息進行構建,相當於函數ϕ\phi,也就是計算出所有鄰居節點的應該傳遞過來的信息量爲多少;根據信息傳遞方向的不同(詳見(1)中的參數flow),節點對的選取方式也不同。值得注意的是,該函數所需的參數是來自於最初傳遞給propagate()函數的參數中的任何參數,換句話說,你要想在message中使用圖的某些屬性參數,必須在propagate()中先傳遞。另外,傳遞給propagate()tensor會通過增加_i_j的方式來創建新的變量名,該變量作爲tensor分別映射到節點i和節點j的值。
(4)torch_geometric.nn.MessagePassing.update()
將聚合函數後的結果作爲輸入計算出更新值。接受聚合過程的輸出結果作爲第一個參數,和其他任意之前傳遞給propagate()的參數。

3.GCN層的實現

從數學角度看,GCN層即:
在這裏插入圖片描述
鄰居節點的特徵首先通過一個權重矩陣的轉換,然後通過它們的度進行標準化,最後進行求和。具體步驟如下:
1.在鄰接矩陣中增加自環
2.對節點特徵進行一次線性轉化(利用linear層實現)
3.計算標準化係數
4.對節點特徵進行標準化(函數ϕ\phi
5.對相鄰節點特徵進行求和("add"聚合方式)
6.得到新的節點嵌入

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.datasets import TUDataset

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # X: [N, in_channels]
        # edge_index: [2, E]

        # 1.在鄰接矩陣中增加自環
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # 2.對節點特徵進行一個非線性轉換
        # x的維度會由[N, in_channels]轉換爲[N, out_channels]
        x = self.lin(x)

        # 3.計算標準化係數
        # edge_index的第一個向量作爲行座標,第二個向量作爲列座標
        row, col = edge_index
        deg = degree(row, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-1/2)
        # norm的第一個元素就是edge_index中的第一列(第一條邊)上的標準化係數
        # tensor的乘法爲對應元素乘法,tensor1[tensor2]後的維度與tensor2一致
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # 4-6步的開始標誌,內部實現了message-AGGREGATE-update
        return self.propagate(edge_index, size=(x.size(0), x.size(1)), x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j的維度爲[E, out_channels]

        # 4.進行傳遞消息的構造,將標準化係數乘以鄰域節點的特徵信息得到傳遞信息
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out的維度爲[N, out_channels]

        # 6.更新新的節點嵌入,這裏沒有做任何多餘的映射過程
        return aggr_out

# 實例化對象
conv = GCNConv(16, 32)
# 默認爲調用對象的forward函數
x = conv(x, edge_index)

對於上面的代碼,GCNConv全部的計算流程都在forward()函數中,在該函數中,前三步是明確計算出來,但是第4-6步是隱含在propagate()函數中進行調用,propagate()函數會調用重載後message()函數和update()函數,並且自身實現了聚合過程。下面測試一下x_j的取值:
(1)取消linear過程
(2)在message函數中輸出x_j
初始化信息爲:

# 構建數據
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)
x = torch.tensor([
    [0, 0, 0],
    [1, 1, 1],
    [2, 2, 2]
], dtype=torch.float)

輸出的x_j爲:

tensor([[0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])

所以x_j對應的節點序列爲[0,1,1,2,0,1,2],而egde_index增加自環之後,是:

tensor([[0, 1, 1, 2, 0, 1, 2],
        [1, 0, 2, 1, 0, 1, 2]])

所以x_j對應第一行節點的特徵信息。

【注】後面將會對PyG自帶的例子進行分析以及相關API。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章