RCNN模型也是用於文本分類的常用模型,其源論文爲Recurrent Convolutional Neural Networks for Text Classification。
模型整體結構如下:
架構主要包括如下模塊:
(1)通過雙向RNN模型,得到每個token上下文的信息(隱層輸出):
(2)通過隱層輸出與原始embedding的拼接,得到擴展後的token信息;
(3)後面接於TextCNN的CNN、max-pooling和fc層,得到分類結果。
整個模型結構還是非常清晰的,下面給出pytorch的簡單實現:
import torch
import torch.nn as nn
Config = {"vob_size": 100, # 字典尺寸
"ebd_size": 50, # 詞嵌入維度
"hidden_size": 20, # 字典尺寸
"num_layer": 2,
"bidirectiion": True, # 雙向
"drop":0.3, # dropout比例
"cnn_channel":100, # 1D-CNN的output_channel
"cnn_kernel": 3, # 1D-CNN的卷積核
"topk": 10, # cnn的output結果取top-k
"fc_hidden": 10, # 全連接層的隱藏層
"fc_cla": 4, # 全連接層的輸出類別
}
class LSTM_pool(nn.Module):
def __init__(self):
super(LSTM_pool, self).__init__()
self.embedding = nn.Embedding(Config['vob_size'], Config['ebd_size'])
self.lstm = nn.LSTM(
input_size=Config['ebd_size'],
hidden_size=Config['hidden_size'],
num_layers=Config['num_layer'],
bidirectional=True,
batch_first=True,
dropout=Config['drop']
)
self.cnn = nn.Sequential(
nn.Conv1d(
in_channels=Config['hidden_size'] * 2 + Config['ebd_size'], # 詞向量和output維度做concat
out_channels=Config['cnn_channel'],
kernel_size=Config['cnn_kernel']),
nn.BatchNorm1d(Config['cnn_channel']),
nn.ReLU(inplace=True),
nn.Conv1d(
in_channels=Config['cnn_channel'],
out_channels=Config['cnn_channel'],
kernel_size=Config['cnn_kernel']),
nn.BatchNorm1d(Config['cnn_channel']),
nn.ReLU(inplace=True)
)
self.fc = nn.Sequential(
nn.Linear(Config['topk'] * Config['cnn_channel'], Config['fc_hidden']), # 2爲bidirectional的拼接結果
nn.BatchNorm1d(Config['fc_hidden']),
nn.ReLU(inplace=True),
nn.Linear(Config['fc_hidden'], Config['fc_cla'])
)
@staticmethod
def topk_pooling(x, k, dim):
index = torch.topk(x, k, dim=dim)[1]
return torch.gather(x, dim=dim, index=index)
def forward(self, x):
emb = self.embedding(x)
out, _ = self.lstm(emb) # (B, S, 2H)
out = torch.cat([emb, out], dim=-1) # (B, S, E) + (B, S, 2H) = (B, S, 2H+E)
out = out.permute((0, 2, 1)) # (B, 2H+E, S)
out = self.cnn(out) # (B, C, S-m)
x = self.topk_pooling(out, k=Config['topk'], dim=-1) # sequence_len方向取top2, (B, C, k)
x = x.view((x.size(0), -1)) # (B, C*k)
logits = self.fc(x)
return logits