多標籤分類:A Deep Reinforced Sequence-to-Set Model for Multi-Label Classification

文章地址:https://arxiv.org/pdf/1809.03118.pdf

代碼地址:https://github.com/lancopku/Seq2Set

文章標題:A Deep Reinforced Sequence-to-Set Model for Multi-Label Classification(多標籤分類的深度增強序列集模型)ACL2019

Abstract

多標籤分類(MLC)旨在預測給定實例的一組標籤。基於預先定義的標籤順序,通過最大似然估計方法訓練的序列-序列序列(Seq2Seq)模型已成功地應用於MLC任務,並顯示出強大的能力來捕獲標籤之間的高階相關性。然而,輸出標籤本質上是一個無序集,而不是有序序列。這種不一致性往往會導致一些棘手的問題,如對標籤順序的敏感性。爲了解決這個問題,我們提出了一個簡單而有效的序列到集合模型。提出的模型通過強化學習進行訓練,其中獎勵反饋被設計成獨立於標籤順序。通過這種方式,我們可以減少模型對標籤順序的依賴,並捕獲標籤之間的高階相關性。大量的實驗表明,我們的方法可以大大超過競爭的基線,以及有效地降低標籤順序的敏感性。

一、Introduction

多標籤分類(MLC)旨在爲每個樣本分配多個標籤。它可以應用於許多真實的場景,如文本分類(Schapire和Singer, 2000)和信息檢索(Gopal和Yang, 2010)。由於標籤之間的複雜依賴性,如何有效地捕獲標籤之間的高階相關性是MLC任務的關鍵挑戰(Zhang and Zhou, 2014)。

在涉及到獲取標籤之間的高階相關性時,有一條研究路線側重於探索標籤空間的層次結構(Prabhu and Varma, 2014; Jernite et al.,2017; Peng et al., 2018; Singh et al., 2018),,而另一行則努力擴展特定的學習算法(Zhang and Zhou, 2006; Baker and Korhonen, 2017; Liu et al., 2017)。然而,這些工作往往導致棘手的計算成本(Chen et al., 2017)。

最近,基於預先定義的標籤順序,Nam et al. (2017); Yang et al. (2018)成功地將sequence-to-sequence (Seq2Seq)模型應用到MLC任務中,顯示出其強大的捕獲高階標籤關聯的能力,並取得了優異的性能。然而,Seq2Seq模型在MLC任務上存在一些棘手的缺陷。輸出標籤本質上是一個帶有swapping-invariance(意味着交換集合中的任何兩個元素都沒有區別)的無序集,而不是一個有序序列。這種不一致性通常會導致一些棘手的問題,例如對標籤順序的敏感性。之前的工作(Vinyals et al., 2016)已經表明,順序對Seq2Seq模型的性能有很大的影響。因此,分類器的性能對預先定義的標籤順序非常敏感。此外,即使該模型準確預測了所有的真標籤,但由於與預先定義的標籤序列的順序不一致,仍可能導致不合理的訓練損失。

因此,在本研究中,我們提出了一種簡單而有效的序列-集合模型,旨在減輕模型對標籤順序的依賴。我們使用強化學習(RL) (Sutton et al., 1999)來指導模型訓練,而不是最大化預先定義的標籤序列的日誌可能性。設計的獎勵不僅全面評價了輸出標籤的質量,而且滿足了集的切換不變性,減少了模型對標籤順序的依賴。

本文的主要貢獻總結如下:

  • 提出了一種簡單有效的基於強化學習的序列集(Seq2Set)模型,該模型不僅捕獲了標籤之間的相關性,而且減輕了對標籤順序的依賴。
  • 實驗結果表明,我們的Seq2Set模型的性能大大優於基線。進一步的分析表明,我們的方法可以有效地降低模型對標籤順序的敏感性。

二、Methodology

2.1 Overview

這裏我們定義了一些必要的符號並描述了MLC任務。給定一個文本序列x包含m個詞,多標籤分類任務的目標是分配一個子集y包含n個標籤在總標籤集y到x。從序列的角度學習,一旦輸出標籤的順序是預定義的,多標籤分類任務可以被視爲目標標籤序列的生成y條件在源文本序列x。

2.2 Neural Sequence-to-Set Model

我們提出的Seq2Set模型由編碼器E和集合解碼器D組成,具體介紹如下。

(1)Encoder E
我們將編碼器E實現爲一個雙向LSTM。給定輸入文本(x1,…, xm),編碼器計算每個詞的隱藏狀態如下:
在這裏插入圖片描述
其中e(xi)爲xi的嵌入。第i個單詞的最終表示是hi,其中分號表示向量連接。

(2)Set decoder D
由於LSTM強大的能力來建模序列依賴性,我們也實現了D作爲一個LSTM模型來捕獲標籤之間的高階相關性。實際上,第t時刻集合解碼器D的隱藏狀態st計算爲:
在這裏插入圖片描述
在[e(yt-1); ct]表示向量的級聯e(yt-1)和ct, e(yt-1)是標籤的嵌入yt-1在上一個時間步生成的,ct是通過注意機制獲得的上下文向量。讀者可以參考Bahdanau等人(2015)瞭解更多細節。最後,集合解碼器D從輸出概率分佈中對標籤yt進行採樣,計算如下:
在這裏插入圖片描述
其中W1、W2、U爲可訓練參數,f爲非線性激活函數,其It是爲防止D產生重複標籤的掩碼向量,
在這裏插入圖片描述

2.3 Model Training

(1)MLC as a RL Problem
爲了減輕模型對標籤順序的依賴,這裏我們將MLC任務建模爲一個RL問題。我們的集合解碼器D可以看作是一個代理,它在t時刻的狀態是當前生成的標籤(y1,…, yt-1)。由參數D定義的隨機策略決定動作,即對下一個標籤的預測。一旦生成完整的標籤序列y, 代理D將得到獎勵r。訓練目標是最小化負的期望獎勵,具體如下:
在這裏插入圖片描述
在我們的模型中,我們使用了自批判策略梯度算法(Rennie et al., 2017)。對於minibatch中的每個訓練樣本,Eq.(6)的梯度近似爲:
在這裏插入圖片描述
其中ys爲概率分佈p採樣的標籤序列,yg爲貪婪搜索算法生成的標籤序列。Eq.(7)中的r(yg)爲基線,其目的是降低梯度估計的方差,增強模型訓練和測試的一致性,緩解exposure bias(Ranzato et al., 2016)。

(2)Reward Design
理想的獎勵應該是對生成的標籤質量的良好度量。此外,爲了使模型不受標籤順序的嚴格限制,還應滿足輸出標籤集的swappingconstant。爲此,我們將生成的標籤y與ground-truth標籤y*進行比較,設計出F1的積分作爲獎勵r。
在這裏插入圖片描述
我們也嘗試了其他的獎勵設計,比如漢明精度。結果表明,基於F1分數的獎勵是最佳的綜合表現。

三、Experiments

3.1 Datasets

我們在RCV1-V2語料庫上進行實驗(Lewis et al., 2004),該語料庫包含大量手動分類的新聞專線故事。標籤的總數是103個。Yang等(2018)也採用了同樣的數據分解方法。

3.2 Settings

我們根據微f1分數調整驗證集上的超參數。詞彙量爲50,000,批處理大小爲64。我們將嵌入大小設置爲512。編碼器和集解碼器都是2層的LSTM,隱藏大小爲512,但前者設置爲雙向。我們用MLE(極大似然估計)方法對模型進行了20個epoch的預訓練。優化器是Adam(Kingma和Ba, 2015)與10-3訓練的學習速率和10-5RL(強化學習)學習率。此外,我們使用dropout (Srivastava et al., 2014)來避免過度擬合,並剪切梯度(Pascanu et al., 2013)到最大範數8。

3.3 Baselines

我們將我們的方法與以下競爭性基線進行比較:

  • BR-LR:相當於爲每個標籤獨立訓練一個二元分類器(邏輯迴歸)。
  • PCC-LR:將MLC任務轉換爲二進制分類(邏輯迴歸)問題鏈。
  • FastXML:學習訓練實例的層次結構,並在層次結構的每個節點上優化目標。
  • XML-CNN:使用動態最大池機制和隱藏的瓶頸層來更好地表示文檔。
  • CNN-RNN:提出了一種CNN和RNN的集成方法來捕獲全局和局部文本語義。
  • Seq2Seq:採用Seq2Seq模型進行多標籤分類

3.4 Evaluation Metrics

評價指標包括:計算誤分率的子集0-1損失,表示誤預測標籤佔總標籤的比例的漢明損失,以及表示每個類的F1分的加權平均值的micro-F1。微精度和微召回也供參考。

四、Results and Discussion

本文對模型和實驗結果進行了深入分析。爲簡單起見,我們使用BR來表示基線BR- LR。

4.1 Experimental Results

在這裏插入圖片描述
我們的方法和所有基線的比較如表1所示,表明所提出的Seq2Set模型在所有評價指標上都比所有基線有較大的優勢。與完全忽略標籤相關性的BR相比,我們的Seq2Set模型減少了12.05%的漢明損失。結果表明,對高階標籤相關關係進行建模可以大大改善結果。與對標籤訂單有嚴格要求的Seq2Seq相比,我們的Seq2Set模型在RCV1-V2數據集上減少了3.95%的漢明損失。這表明我們的方法可以通過減少模型對標籤訂單的依賴來實現實質性的改進。

4.2 Reducing Sensitivity to Label Order

在這裏插入圖片描述
爲了驗證我們的方法可以降低標籤順序的敏感性,我們隨機打亂標籤序列的順序。表2展示了不同模型在標籤變換的RCV1-V2數據集上的性能。結果表明,對於打亂的標籤順序,BR沒有受到影響,但是Seq2Seq的性能卻急劇下降。因爲Seq2Seq的解碼器本質上是一個條件語言模型。它嚴重依賴於一個合理的標籤順序來建模標籤之間的內在關聯,而在這種情況下,標籤呈現無序狀態。然而,我們的模型在子集0 - 1損失上的性能僅下降了1.2%5,而Seq2Seq下降了9.3%。這說明我們的Seq2Set模型具有更強的魯棒性,可以抵抗標籤順序的干擾。我們的模型是通過強化學習來訓練的,獎勵反饋與標籤順序無關,降低了對標籤順序的敏感性。

4.3 Improving Model Universality

RCV1-V2數據集中的標籤呈現長尾分佈。然而,在實際場景中,還有其他常見的標籤分佈,如均勻分佈(Lin et al., 2018a)。因此,這裏我們分析了Seq2Set模型的通用性,這意味着它可以在不同的標籤分發情況下實現穩定的性能改進。詳細地,我們依次刪除RCV1-V2數據集中最頻繁的k標籤,並對其餘標籤執行評估。k越大,標籤分佈越均勻。圖1顯示了不同系統的性能變化。
在這裏插入圖片描述
首先,隨着移除高頻標籤的數量增加,所有方法的性能都會下降。這是合理的,因爲預測低頻標籤相對困難。但是,與其他方法相比,Seq2Seq模型的性能大大降低。我們懷疑這是因爲統一分佈的標籤很難定義一個合理的訂單,而Seq2Seq對標籤的訂單有嚴格的要求。這種衝突可能會損害性能。然而,如圖1所示,隨着更多的標籤被刪除,Seq2Set相對於Seq2Seq的優勢繼續增強。這說明我們的Seq2Set模型具有良好的通用性,適用於不同的標籤分發。我們的方法不僅具有Seq2Seq捕獲標籤相關性的能力,而且通過強化學習,減輕了Seq2Seq對標籤順序的嚴格要求。這樣就避免了在均勻分佈上預先定義合理的標籤順序的困難,從而具有很好的通用性。

4.4 Error Analysis

在這裏插入圖片描述
我們發現所有的方法在預測低頻(LF)標籤和高頻(HF)標籤時的表現都很差。這是合理的,因爲分配給LF標籤的樣本是稀疏的,使得模型很難學習有效的模式來進行預測。圖2爲不同方法對HF標籤和LF標籤的檢測結果。但是,與其他系統相比,我們提出的Seq2Set模型在LF標籤和HF標籤上都有更好的性能。此外,我們的方法在LF標籤上取得的相對改進要大於HF標籤。事實上,LF標籤的分佈較爲均勻。如4.3節所分析的,在均勻分佈中,標籤訂單問題更爲嚴重。我們的Seq2Set模型可以通過強化學習來減少對標籤順序的依賴,從而使LF標籤的性能有較大的提高。

五、Related Work

多標籤分類(MLC)旨在爲數據集中的每個樣本分配多個標籤。早期對MLC任務的研究主要集中在機器學習算法上,主要包括問題轉換方法和算法適應方法。問題轉換方法,如BR (Boutell et al., 2004)、LP (Tsoumakas和Katakis, 2006)和CC (Read et al., 2011),旨在將MLC任務映射成多個單標記學習任務。算法適應方法力求擴展特定的學習算法,直接處理多標籤數據。相應的代表作有ML-DT (Clare and King, 2001)、Rank-SVM (Elisseeff and Weston, 2001)、ML-KNN (Zhang and Zhou, 2007)等。此外,其他一些方法,包括集成方法(Tsoumakas et al., 2011)和聯合訓練(Li et al., 2015),也可以用於MLC任務。然而,它們只能用於捕獲一階或二階標籤相關性(Chen et al., 2017),或者在考慮高階標籤相關性時是計算上難以處理的。

近年來,一些神經網絡模型也被成功地用於MLC任務。例如,Zhang和Zhou(2006)提出的BP-MLL採用全連通網絡和兩兩排序損失進行分類。Nam等(2013)進一步用交叉熵損失函數代替兩兩排序損失。Kurata等人(2016)提出了一種利用神經元對標籤相關性進行建模的初始化方法。Chen等人(2017)提出了CNN和RNN的集成方法來捕獲全局和局部語義信息。Liu等人(2017)使用動態最大池機制和隱藏的瓶頸層來更好地表示文檔。Peng等人(2018)利用圖卷積運算來捕獲非連續和長距離語義。這兩個里程碑是Nam et al.(2017)和Yang et al.(2018),兩者都利用Seq2Seq模型來捕獲標籤相關性。更進一步,Lin等(2018b)提出了一種基於語義單元的擴展卷積模型,Zhao等(2018)提出了一種基於標籤圖的神經網絡,該神經網絡採用軟訓練機制來捕獲標籤相關性。最近,Qin等人(2019)提出了新的基於集合概率的訓練目標,有效地對集合的數學特徵進行建模。

六、Conclusion

在本研究中,我們提出一種簡單而有效的基於強化學習的序對集模型,其目的在於減少對標籤順序序對集模型的嚴格要求。該模型不僅捕獲了標籤之間的高階相關性,而且減少了對輸出標籤順序的依賴。實驗結果表明,我們的Seq2Set模型能夠大幅度地超越競爭基線。進一步的分析表明,我們的方法可以有效地降低標籤訂單的敏感性。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章