圖神經網絡15-Text-Level-GNN:基於文本級GNN的文本分類模型

論文題目:Text Level Graph Neural Network for Text Classification
論文地址:https://arxiv.org/pdf/1910.02356.pdf
論文代碼:https://github.com/yenhao/text-level-gnn
發表時間:2019

論文簡介與動機

1)TextGCN爲整個數據集/語料庫構建一個異構圖(包括(待分類)文檔節點和單詞節點),邊的權重是固定的(單詞節點間的邊權重是兩個單詞的PMI,文檔-單詞節點間的邊權重是TF-IDF),固定權重限制了邊的表達能力,而且爲了獲取一個全局表示不得不使用一個非常大的連接窗口。因此,構建的圖非常大,而且邊非常多,模型由很大的內存消耗。

2)上篇博客也提到了,TextGCN這種類型的模型,無法爲新樣本(文本)進行分類(在線測試),因爲圖的架構和參數依賴於語料庫/數據集,訓練結束後就不能再修改了。(除非將新文本加入到語料庫中,更新圖的結構,重新訓練......一般不會這樣做,總之該類模型不能爲新文本進行分類)

本篇論文提出了一個新的基於GNN的模型來做文本分類,解決了上述兩個問題:

1)爲每個輸入文本/數據(text-level)都單獨構建一個圖,文本中的單詞作爲節點;而不是給整個語料庫/數據集(corpus-level)構建一個大圖(每個文本和單詞作爲節點)。在每個文本中,使用一個非常小的滑動窗口,文本中的每個單詞只與其左右的p個詞有邊相連(包括自己,自連接),而不是所有單詞節點全連接。

2)相同單詞節點的表示以及相同單詞對之間邊的權重全局(數據集/語料庫中的所有文本/數據)共享,通過文本級別圖的消息傳播機制進行更新。

這樣就可以消除單個輸入文本和整個語料庫/數據集的依賴負擔,支持在線測試(新文本測試);而且上下文窗口更小,邊數更少,內存消耗更小。

Text-Level-GNN模型

構建文本圖

對於給定的一個包含l個詞的文本記爲T=\{{r_{1},r_{2},...,r_{l}}\},其中r_{i}代表文本中第i個單詞的表示,初始化一個全局共享的詞嵌入矩陣(使用預訓練詞向量初始化),每個單詞/節點的初始表示從該嵌入矩陣中查詢,嵌入矩陣作爲模型參數在訓練過程中更新。

爲每個輸入文本/數據構建一個圖,把文本中的單詞看作是節點,每個單詞和它左右相鄰的p個單詞有邊相連(包括自己,自連接)。輸入文本T的圖表示爲:
N=\{r_{i}|i\in[1,l] \}
E=\{e_{ij}|i\in[1,l] \}
其中N和E是文本圖的節點集和邊集,每個單詞節點的表示,以及單詞節點間邊的權重分別來自兩個全局共享矩陣(模型參數,訓練過程中更新)。此外,對於訓練集中出現次數少於k(k=2)次的邊(詞對)均勻地映射到一個"公共邊",使得參數充分學習。

如上圖所示:一個文本Text Level Graph爲一個單獨的文本“he is very proud of you.”。爲了顯示方便,在這個圖中,爲節點“very”(節點和邊用紅色表示)設置p= 2,爲其他節點(用藍色表示)設置p= 1。在實際情況下,會話期間p的值是唯一的。圖中的所有參數都來自圖底部顯示的全局共享表示矩陣。

與以往構建圖的方法相比,該方法可以極大地減少圖的節點和邊的規模。這意味着文本級圖形可以消耗更少的GPU內存。

消息傳遞機制

卷積可以從局部特徵中提取信息。在圖域中,卷積是通過頻譜方法或非頻譜方法實現的。在本文中,一種稱爲消息傳遞機制(MPM)的非頻譜方法被用於卷積。MPM首先從相鄰節點收集信息,並根據其原始表示形式和所收集的信息來更新其表示形式,其定義爲:

其中M_{n}\in R^{d}是節點n從其鄰居接收到的消息;max是一種歸約函數,它將每個維上的最大值組合起來以形成一個新的向量作爲輸出。N_{n}^{p}代表原始文本中n的最近p個單詞的節點;e_{an}\in R^{1}是從節點a到節點n的邊緣權重,它可以訓練時更新;r_{n}\in R^{d}代表節點n先前的表示向量。\eta_{n} \in R^{1}節點n的可訓練的變量,指示應該保留多少r_{n}的信息。r_{n}^{'}代表節點n更新後的表示。

MPM使節點的表示受到鄰域的影響,這意味着表示可以從上下文中獲取信息。因此,即使對於一詞多義,上下文中的精確含義也可以通過來自鄰居的加權信息的影響來確定。此外,文本級圖的參數取自全局共享矩陣,這意味着表示形式也可以像其他基於圖的模型一樣帶來全局信息。

最後,使用文本中所有節點的表示來預測文本的標籤:
y_{i}=softmax(Relu(W\sum_{n \in N_{}i}r_{n}^{'}+b))

其中W\in R^{dxc}是將向量映射到輸出空間的矩陣,N_{i}是文本i的節點集,b\in R^{c}是偏差。
訓練的目的是最小化真實標籤和預測標籤之間的交叉熵損失:

loss=-g_{i}logy_{i},其中g_{i}是真實標籤的one-hot向量表示。

實驗結果

不同模型的對比實驗

數據集採用了R8,R52和Ohsumed。R8和R52都是路透社21578數據集的子集。


p值影響

消融實驗

(1)取消邊之間的權重,性能變差,說明爲邊設置權重較好。
(2)mean取代max
(3)去掉預訓練詞嵌入


核心代碼

獲取鄰居詞:https://github.com/yenhao/text-level-gnn/blob/master/utils.py

def get_word_neighbors_mp(text_tokens:list, neighbor_distance:int) :
    print("\tGet word's neighbors")
    with mp.Pool(mp.cpu_count()) as p:
        return p.starmap(get_word_neighbor, map(lambda tokens: (tokens, neighbor_distance), text_tokens))

def get_word_neighbor(text_tokens: list, neighbor_distance: int) :
    """Get word token's adjacency neighbors with distance : neighbor_distance
    Args:
        text_tokens (list): A list of the tokens of sentences/texts from dataset.
        neighbor_distance (int): The adjacency distance to consider as a neighbor.
    Returns:
        list: A nested list with 2 dimensions, which is a list of neighbor word tokens (2nd dim) for all tokens (1nd dim)
    """
    text_len = len(text_tokens)

    edge_neighbors = []
    for w_idx in range(text_len):
        skip_neighbors = []
        # check before
        for sk_i in range(neighbor_distance):
            before_idx = w_idx -1 - sk_i
            skip_neighbors.append(text_tokens[before_idx] if before_idx > -1 else 0)

        # check after
        for sk_i in range(neighbor_distance):
            after_idx = w_idx +1 +sk_i
            skip_neighbors.append(text_tokens[after_idx] if after_idx < text_len else 0)

        edge_neighbors.append(skip_neighbors)
    return edge_neighbors

TextLevelGNN層:https://github.com/yenhao/text-level-gnn/blob/master/model.py

class TextLevelGNN(nn.Module):

    def __init__(self, num_nodes, node_feature_dim, class_num, embeddings=0, embedding_fix=False):
        super(TextLevelGNN, self).__init__()

        if type(embeddings) != int:
            print("\tConstruct pretrained embeddings")
            self.node_embedding = nn.Embedding.from_pretrained(embeddings, freeze=embedding_fix, padding_idx=0)
        else:
            self.node_embedding = nn.Embedding(num_nodes, node_feature_dim, padding_idx = 0)

        # self.edge_weights = nn.Embedding((num_nodes-1) * (num_nodes-1) + 1, 1, padding_idx=0) # +1 is padding
        self.edge_weights = nn.Embedding(num_nodes * num_nodes, 1) # +1 is padding
        self.node_weights = nn.Embedding(num_nodes, 1, padding_idx=0) # Nn, node weight for itself

        self.fc = nn.Sequential(
            nn.Linear(node_feature_dim, class_num, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Softmax(dim=1)
        )

    def forward(self, X, NX, EW):
        """
        INPUT:
        -------
        X  [tensor](batch, sentence_maxlen)               : Nodes of a sentence
        NX [tensor](batch, sentence_maxlen, neighbor_distance*2): Neighbor nodes of each nodes in X
        EW [tensor](batch, sentence_maxlen, neighbor_distance*2): Neighbor weights of each nodes in X
        OUTPUT:
        -------
        y  [list] : Predicted Probabilities of each classes
        """
        ## Neighbor
        # Neighbor Messages (Mn)
        Mn = self.node_embedding(NX) # (BATCH, SEQ_LEN, NEIGHBOR_SIZE, EMBED_DIM)

        # EDGE WEIGHTS
        En = self.edge_weights(EW) # (BATCH, SEQ_LEN, NEIGHBOR_SIZE )

        # get representation of Neighbors
        Mn = torch.sum(En * Mn, dim=2) # (BATCH, SEQ_LEN, EMBED_DIM)

        # Self Features (Rn)
        Rn = self.node_embedding(X) # (BATCH, SEQ_LEN, EMBED_DIM)

        ## Aggregate information from neighbor
        # get self node weight (Nn)
        Nn = self.node_weights(X)
        Rn = (1 - Nn) * Mn + Nn * Rn

        # Aggragate node features for sentence
        X = Rn.sum(dim=1)

        y = self.fc(X)
        return y

結論

本文提出了一個新的基於圖的文本分類模型,該模型使用文本級圖而不是整個語料庫的單個圖。實驗結果表明,我們的模型達到了最先進的性能,並且在內存消耗方面具有顯着優勢。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章