論文題目:Text Level Graph Neural Network for Text Classification
論文地址:https://arxiv.org/pdf/1910.02356.pdf
論文代碼:https://github.com/yenhao/text-level-gnn
發表時間:2019
論文簡介與動機
1)TextGCN爲整個數據集/語料庫構建一個異構圖(包括(待分類)文檔節點和單詞節點),邊的權重是固定的(單詞節點間的邊權重是兩個單詞的PMI,文檔-單詞節點間的邊權重是TF-IDF),固定權重限制了邊的表達能力,而且爲了獲取一個全局表示不得不使用一個非常大的連接窗口。因此,構建的圖非常大,而且邊非常多,模型由很大的內存消耗。
2)上篇博客也提到了,TextGCN這種類型的模型,無法爲新樣本(文本)進行分類(在線測試),因爲圖的架構和參數依賴於語料庫/數據集,訓練結束後就不能再修改了。(除非將新文本加入到語料庫中,更新圖的結構,重新訓練......一般不會這樣做,總之該類模型不能爲新文本進行分類)
本篇論文提出了一個新的基於GNN的模型來做文本分類,解決了上述兩個問題:
1)爲每個輸入文本/數據(text-level)都單獨構建一個圖,文本中的單詞作爲節點;而不是給整個語料庫/數據集(corpus-level)構建一個大圖(每個文本和單詞作爲節點)。在每個文本中,使用一個非常小的滑動窗口,文本中的每個單詞只與其左右的p個詞有邊相連(包括自己,自連接),而不是所有單詞節點全連接。
2)相同單詞節點的表示以及相同單詞對之間邊的權重全局(數據集/語料庫中的所有文本/數據)共享,通過文本級別圖的消息傳播機制進行更新。
這樣就可以消除單個輸入文本和整個語料庫/數據集的依賴負擔,支持在線測試(新文本測試);而且上下文窗口更小,邊數更少,內存消耗更小。
Text-Level-GNN模型
構建文本圖
對於給定的一個包含l
個詞的文本記爲,其中代表文本中第個單詞的表示,初始化一個全局共享的詞嵌入矩陣(使用預訓練詞向量初始化),每個單詞/節點的初始表示從該嵌入矩陣中查詢,嵌入矩陣作爲模型參數在訓練過程中更新。
爲每個輸入文本/數據構建一個圖,把文本中的單詞看作是節點,每個單詞和它左右相鄰的個單詞有邊相連(包括自己,自連接)。輸入文本的圖表示爲:
其中N和E是文本圖的節點集和邊集,每個單詞節點的表示,以及單詞節點間邊的權重分別來自兩個全局共享矩陣(模型參數,訓練過程中更新)。此外,對於訓練集中出現次數少於k(k=2)次的邊(詞對)均勻地映射到一個"公共邊",使得參數充分學習。
如上圖所示:一個文本Text Level Graph爲一個單獨的文本“he is very proud of you.”。爲了顯示方便,在這個圖中,爲節點“very”(節點和邊用紅色表示)設置,爲其他節點(用藍色表示)設置。在實際情況下,會話期間的值是唯一的。圖中的所有參數都來自圖底部顯示的全局共享表示矩陣。
與以往構建圖的方法相比,該方法可以極大地減少圖的節點和邊的規模。這意味着文本級圖形可以消耗更少的GPU內存。
消息傳遞機制
卷積可以從局部特徵中提取信息。在圖域中,卷積是通過頻譜方法或非頻譜方法實現的。在本文中,一種稱爲消息傳遞機制(MPM)的非頻譜方法被用於卷積。MPM首先從相鄰節點收集信息,並根據其原始表示形式和所收集的信息來更新其表示形式,其定義爲:
其中是節點從其鄰居接收到的消息;是一種歸約函數,它將每個維上的最大值組合起來以形成一個新的向量作爲輸出。代表原始文本中的最近個單詞的節點;是從節點到節點的邊緣權重,它可以訓練時更新;代表節點n先前的表示向量。節點n的可訓練的變量,指示應該保留多少的信息。代表節點更新後的表示。
MPM使節點的表示受到鄰域的影響,這意味着表示可以從上下文中獲取信息。因此,即使對於一詞多義,上下文中的精確含義也可以通過來自鄰居的加權信息的影響來確定。此外,文本級圖的參數取自全局共享矩陣,這意味着表示形式也可以像其他基於圖的模型一樣帶來全局信息。
最後,使用文本中所有節點的表示來預測文本的標籤:
其中是將向量映射到輸出空間的矩陣,是文本的節點集,是偏差。
訓練的目的是最小化真實標籤和預測標籤之間的交叉熵損失:
,其中是真實標籤的one-hot向量表示。
實驗結果
不同模型的對比實驗
數據集採用了R8,R52和Ohsumed。R8和R52都是路透社21578數據集的子集。
p值影響
消融實驗
(1)取消邊之間的權重,性能變差,說明爲邊設置權重較好。
(2)mean取代max
(3)去掉預訓練詞嵌入
核心代碼
獲取鄰居詞:https://github.com/yenhao/text-level-gnn/blob/master/utils.py
def get_word_neighbors_mp(text_tokens:list, neighbor_distance:int) :
print("\tGet word's neighbors")
with mp.Pool(mp.cpu_count()) as p:
return p.starmap(get_word_neighbor, map(lambda tokens: (tokens, neighbor_distance), text_tokens))
def get_word_neighbor(text_tokens: list, neighbor_distance: int) :
"""Get word token's adjacency neighbors with distance : neighbor_distance
Args:
text_tokens (list): A list of the tokens of sentences/texts from dataset.
neighbor_distance (int): The adjacency distance to consider as a neighbor.
Returns:
list: A nested list with 2 dimensions, which is a list of neighbor word tokens (2nd dim) for all tokens (1nd dim)
"""
text_len = len(text_tokens)
edge_neighbors = []
for w_idx in range(text_len):
skip_neighbors = []
# check before
for sk_i in range(neighbor_distance):
before_idx = w_idx -1 - sk_i
skip_neighbors.append(text_tokens[before_idx] if before_idx > -1 else 0)
# check after
for sk_i in range(neighbor_distance):
after_idx = w_idx +1 +sk_i
skip_neighbors.append(text_tokens[after_idx] if after_idx < text_len else 0)
edge_neighbors.append(skip_neighbors)
return edge_neighbors
TextLevelGNN層:https://github.com/yenhao/text-level-gnn/blob/master/model.py
class TextLevelGNN(nn.Module):
def __init__(self, num_nodes, node_feature_dim, class_num, embeddings=0, embedding_fix=False):
super(TextLevelGNN, self).__init__()
if type(embeddings) != int:
print("\tConstruct pretrained embeddings")
self.node_embedding = nn.Embedding.from_pretrained(embeddings, freeze=embedding_fix, padding_idx=0)
else:
self.node_embedding = nn.Embedding(num_nodes, node_feature_dim, padding_idx = 0)
# self.edge_weights = nn.Embedding((num_nodes-1) * (num_nodes-1) + 1, 1, padding_idx=0) # +1 is padding
self.edge_weights = nn.Embedding(num_nodes * num_nodes, 1) # +1 is padding
self.node_weights = nn.Embedding(num_nodes, 1, padding_idx=0) # Nn, node weight for itself
self.fc = nn.Sequential(
nn.Linear(node_feature_dim, class_num, bias=True),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Softmax(dim=1)
)
def forward(self, X, NX, EW):
"""
INPUT:
-------
X [tensor](batch, sentence_maxlen) : Nodes of a sentence
NX [tensor](batch, sentence_maxlen, neighbor_distance*2): Neighbor nodes of each nodes in X
EW [tensor](batch, sentence_maxlen, neighbor_distance*2): Neighbor weights of each nodes in X
OUTPUT:
-------
y [list] : Predicted Probabilities of each classes
"""
## Neighbor
# Neighbor Messages (Mn)
Mn = self.node_embedding(NX) # (BATCH, SEQ_LEN, NEIGHBOR_SIZE, EMBED_DIM)
# EDGE WEIGHTS
En = self.edge_weights(EW) # (BATCH, SEQ_LEN, NEIGHBOR_SIZE )
# get representation of Neighbors
Mn = torch.sum(En * Mn, dim=2) # (BATCH, SEQ_LEN, EMBED_DIM)
# Self Features (Rn)
Rn = self.node_embedding(X) # (BATCH, SEQ_LEN, EMBED_DIM)
## Aggregate information from neighbor
# get self node weight (Nn)
Nn = self.node_weights(X)
Rn = (1 - Nn) * Mn + Nn * Rn
# Aggragate node features for sentence
X = Rn.sum(dim=1)
y = self.fc(X)
return y
結論
本文提出了一個新的基於圖的文本分類模型,該模型使用文本級圖而不是整個語料庫的單個圖。實驗結果表明,我們的模型達到了最先進的性能,並且在內存消耗方面具有顯着優勢。