【文本分類】RCNN模型

RCNN模型也是用於文本分類的常用模型,其源論文爲Recurrent Convolutional Neural Networks for Text Classification

模型整體結構如下:
在這裏插入圖片描述
架構主要包括如下模塊:
(1)通過雙向RNN模型,得到每個token上下文的信息(隱層輸出):
(2)通過隱層輸出與原始embedding的拼接,得到擴展後的token信息;
(3)後面接於TextCNN的CNN、max-pooling和fc層,得到分類結果。

整個模型結構還是非常清晰的,下面給出pytorch的簡單實現:


import torch
import torch.nn as nn


Config = {"vob_size": 100,         # 字典尺寸
          "ebd_size": 50,            # 詞嵌入維度
          "hidden_size": 20,         # 字典尺寸
          "num_layer": 2,
          "bidirectiion": True,   # 雙向
          "drop":0.3,      # dropout比例
          "cnn_channel":100,   # 1D-CNN的output_channel
          "cnn_kernel": 3,    # 1D-CNN的卷積核
          "topk": 10,  # cnn的output結果取top-k
          "fc_hidden": 10,  # 全連接層的隱藏層
          "fc_cla": 4,  # 全連接層的輸出類別
          }


class LSTM_pool(nn.Module):
    def __init__(self):
        super(LSTM_pool, self).__init__()
        self.embedding = nn.Embedding(Config['vob_size'], Config['ebd_size'])
        self.lstm = nn.LSTM(
            input_size=Config['ebd_size'],
            hidden_size=Config['hidden_size'],
            num_layers=Config['num_layer'],
            bidirectional=True,
            batch_first=True,
            dropout=Config['drop']
        )

        self.cnn = nn.Sequential(
            nn.Conv1d(
                in_channels=Config['hidden_size'] * 2 + Config['ebd_size'],  # 詞向量和output維度做concat
                out_channels=Config['cnn_channel'],
                kernel_size=Config['cnn_kernel']),
            nn.BatchNorm1d(Config['cnn_channel']),
            nn.ReLU(inplace=True),

            nn.Conv1d(
                in_channels=Config['cnn_channel'],
                out_channels=Config['cnn_channel'],
                kernel_size=Config['cnn_kernel']),

            nn.BatchNorm1d(Config['cnn_channel']),
            nn.ReLU(inplace=True)

        )

        self.fc = nn.Sequential(
            nn.Linear(Config['topk'] * Config['cnn_channel'], Config['fc_hidden']),   # 2爲bidirectional的拼接結果
            nn.BatchNorm1d(Config['fc_hidden']),
            nn.ReLU(inplace=True),

            nn.Linear(Config['fc_hidden'], Config['fc_cla'])

        )

    @staticmethod
    def topk_pooling(x, k, dim):
        index = torch.topk(x, k, dim=dim)[1]
        return torch.gather(x, dim=dim, index=index)

    def forward(self, x):
        emb = self.embedding(x)
        out, _ = self.lstm(emb)    # (B, S, 2H)
        out = torch.cat([emb, out], dim=-1)   # (B, S, E) + (B, S, 2H) = (B, S, 2H+E)
        out = out.permute((0, 2, 1))    # (B, 2H+E, S)
        out = self.cnn(out)    # (B, C, S-m)
        x = self.topk_pooling(out, k=Config['topk'], dim=-1)   # sequence_len方向取top2,  (B, C, k)
        x = x.view((x.size(0), -1))    # (B, C*k)
        logits = self.fc(x)
        return logits
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章