文本分類論文及 PyTorch 版復現(五):TextLevelGNN

論文:Text Level Graph Neural Network for Text Classification(Huang et al., EMNLP 2019)

一、模型

圖示:

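定義:對於長度為 $l$ 的文本 $T = \{r_1, \dots, r_l\}$,$r_i \in \mathbb{R}^d$ 為第 $i$ 個詞的詞向量。每個詞作為一個節點,並與其前後各 $p$ 個詞之間連邊,得到該文本自己的圖 $G = (N, E)$;節點表示 $r_n$、邊權 $e_{an}$ 以及係數 $\eta_n$ 在整個語料中共享。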

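消息傳遞機制:$M_n = \max_{a \in \mathcal{N}_n^p} e_{an} r_a$,$r'_n = (1 - \eta_n) M_n + \eta_n r_n$,其中 $\max$ 為逐元素取最大,$\mathcal{N}_n^p$ 為節點 $n$ 在窗口 $p$ 內的鄰居,$\eta_n$ 為可訓練的節點參數。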

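分類器:$y_i = \mathrm{softmax}\big(\mathrm{ReLU}(W \sum_{n \in N_i} r'_n + b)\big)$,即先對文本中所有節點的表示求和,再經全連接層分類。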

損失函數:交叉熵 $\mathrm{loss} = -g_i \log y_i$,其中 $g_i$ 為真實標籤的 one-hot 向量。

二、代碼

from torch import nn, tensor
import torch.nn.functional as F
import numpy as np
import torch


class TextLevelGNN(nn.Module):
    """Text Level GNN text classifier (Huang et al., 2019).

    Every text is its own graph: each word ("master" node) is connected to
    the neighbouring words ("slave" nodes) in a window around it. The node
    embeddings ``R``, the scalar edge weights ``E`` and the per-node mixing
    coefficients ``N`` (eta in the paper) are global tables shared by all
    texts.

    Args:
        num_nodes: vocabulary size. Id 0 is reserved for padding, so the
            node tables have ``num_nodes + 1`` rows.
        embedding_dim: size of each node embedding.
        num_classes: number of output classes.
        dropout: dropout probability applied to the pooled graph
            representation before the classifier.

    The defaults reproduce the sizes this module originally hard-coded.
    Note that ``E`` has ``num_nodes * num_nodes + 1`` rows, so its memory
    grows quadratically with the vocabulary size.
    """

    def __init__(self, num_nodes=4904, embedding_dim=300, num_classes=14, dropout=0.5):
        super().__init__()
        # Row 0 of every table is padding; padding_idx keeps it at zero.
        self.R = nn.Embedding(num_nodes + 1, embedding_dim, padding_idx=0)
        # One scalar weight per edge id in [0, num_nodes * num_nodes].
        # NOTE(review): assumes the caller encodes an ordered node pair into a
        # single edge id during preprocessing — confirm against the data pipeline.
        self.E = nn.Embedding(num_nodes * num_nodes + 1, 1, padding_idx=0)
        # Per-node eta: how much of its own embedding a node keeps vs. the
        # aggregated neighbour message.
        self.N = nn.Embedding(num_nodes + 1, 1, padding_idx=0)
        # Dropout acts on the pooled representation *before* the classifier.
        # Placing it after the softmax would zero and rescale probabilities
        # during training. It lives outside ``fc`` so the state_dict keys
        # (``fc.0.weight`` / ``fc.0.bias``) stay the same.
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, num_classes, bias=True),
            nn.ReLU(inplace=True),
            nn.Softmax(dim=1),
        )

    def forward(self, master_nodes, slave_nodes_list, slave_edges_list):
        """Classify a batch of texts.

        Args:
            master_nodes: LongTensor of node ids, shape ``(batch, seq_len)``.
            slave_nodes_list: LongTensor of neighbour node ids, shape
                ``(batch, seq_len, k)``.
            slave_edges_list: LongTensor of edge ids for the corresponding
                neighbours, shape ``(batch, seq_len, k)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding class
            probabilities (softmax already applied). For training, use
            ``NLLLoss`` on ``log`` of the output. ``CrossEntropyLoss`` expects
            unnormalised logits and would apply softmax a second time.
        """
        Rn = self.R(master_nodes)            # (batch, seq_len, dim)
        Ra = self.R(slave_nodes_list)        # (batch, seq_len, k, dim)
        Ean = self.E(slave_edges_list)       # (batch, seq_len, k, 1)
        # Message: element-wise max over the neighbours of edge-weighted embeddings.
        Mn = (Ra * Ean).max(dim=2)[0]        # (batch, seq_len, dim)
        Nn = self.N(master_nodes)            # (batch, seq_len, 1)
        # Mix the neighbour message with the node's own embedding via eta.
        x = (1 - Nn) * Mn + Nn * Rn
        # Sum-pool all nodes of each text, then classify.
        x = self.dropout(x.sum(dim=1))
        return self.fc(x)


if __name__ == '__main__':
    # Smoke test with random inputs. The sizes must match TextLevelGNN's
    # constructor defaults: node ids must lie in [0, num_nodes] and edge ids
    # in [0, num_nodes * num_nodes], otherwise the embedding lookups go out
    # of range.
    num_nodes = 4904
    num_classes = 14
    batch_size = 64
    seq_len = 1000
    window_size = 2
    # Presumably window_size neighbours on each side of every word.
    num_neighbours = window_size * 2

    master_nodes = torch.randint(0, num_nodes + 1, (batch_size, seq_len))
    slave_nodes_list = torch.randint(0, num_nodes + 1, (batch_size, seq_len, num_neighbours))
    slave_edges_list = torch.randint(0, num_nodes * num_nodes + 1, (batch_size, seq_len, num_neighbours))

    model = TextLevelGNN()
    model.eval()
    # No autograd graph is needed for a shape check; this keeps memory down.
    with torch.no_grad():
        y = model(master_nodes, slave_nodes_list, slave_edges_list)
    print(y.shape)
    assert y.shape == (batch_size, num_classes)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章