蒸餾網絡 Distilling the Knowledge in a Neural Network

論文:https://arxiv.org/abs/1503.02531

0 摘要

提高几乎所有機器學習算法性能的一種非常簡單的方法是在相同的數據上訓練許多不同的模型,然後對它們的預測進行平均[3]。不幸的是,使用整個模型集合進行預測是麻煩的,並且可能在計算上太昂貴而不能允許對大量用戶進行部署,尤其是如果各個模型是大型神經網絡。Caruana和他的合作者[1]已經證明,有可能將整體中的知識壓縮成一個更容易部署的單一模型,並且我們使用不同的壓縮技術進一步開發了這種方法。我們在MNIST上取得了一些令人驚訝的結果,我們表明,通過將模型集合中的知識提煉到單個模型中,我們可以顯著改善大量使用的商業系統的聲學模型。我們還介紹了一種由一個或多個完整模型和許多專業模型組成的新型集合,它們學會區分完整模型混淆的細粒度類。與專家混合不同,這些專業模型可以快速並行地進行訓練。

1 簡介

許多昆蟲具有幼蟲形態,其被優化用於從環境中提取能量和營養物,以及完全不同的成體形式,其針對非常不同的旅行和繁殖要求進行了優化。在大規模機器學習中,我們通常在訓練階段和部署階段使用非常相似的模型,但其實我們對它們的要求非常不同:對於語音和對象識別等任務,訓練必須從非常大的、高度冗餘的數據集中提取結構,但不會需要實時操作,訓練過程可以使用大量的計算。但是,模式部署時對延遲和計算資源有更嚴格的要求。與昆蟲的類比表明,如果能夠更容易地從數據中提取結構,我們應該願意訓練非常繁瑣的模型。繁瑣的模型可以是單獨訓練的多個模型的集合,也可以是使用非常強的正則化器(例如dropout)訓練的單個非常大的模型[9]。一旦繁瑣的模型訓練完成,我們可以再使用稱之爲“蒸餾”的訓練過程,將知識從繁瑣的模型轉移到更適合部署的小模型。Rich Caruana及其合作者已經開創了這一戰略的一個版本[1]。在他們的重要論文中,他們令人信服地證明,大的模型集合所獲得的知識可以轉移到一個小型模型中。

可能阻礙對這種非常有前景的方法進行更多概念性研究的問題是我們傾向於使用學習的參數值來識別訓練模型中的知識,這使我們很難看到我們如何改變模型的形式但同時保持相同的知識。關於知識的更抽象的觀點,使其從任何特定的實例中解放出來,就是它是一種學習的從輸入向量到輸出向量的映射。對於學會區分大量類別的繁瑣模型,正常的訓練目標是最大化正確答案的平均對數概率,但學習的副作用是訓練的模型將概率分配給所有不正確的答案。即使這些概率非常小,但其中總有一些比其他概率大得多。不正確答案的相對概率告訴我們很多關於繁瑣模型如何概括的知識。例如,寶馬的圖像可能只有很小的機會被誤認爲是垃圾車,但這個錯誤仍然比將它被誤認爲是胡蘿蔔的可能性高很多倍。

人們普遍認爲,用於訓練的目標函數應儘可能地反映用戶的真實目標。在訓練過程中,人們往往以最優化訓練集的準確率作爲訓練目標,但真實目標卻是最優化模型的泛化能力。顯然如果能以提升模型的泛化能力爲目標進行訓練是最好的,但這需要正確的關於泛化能力的信息,而這些信息通常不可用。當我們將大型模型的知識蒸餾進小型模型時,我們可以訓練小模型以與大型模型相同的方式進行泛化。如果繁瑣的模型是因爲它是很多不同模型的集成才取得了很好地泛化,那麼在同一數據集上,以與繁瑣模型相同泛化方式訓練的小模型通常比以正常方式訓練的小模型具有更好的測試性能。

將繁瑣模型的泛化能力轉移到小模型的一種顯而易見的方法是使用由繁瑣模型產生的類概率作爲訓練小模型的“軟目標”。對於此遷移階段,我們可以使用相同的訓練集或單獨的“轉移”集。當繁瑣的模型是較大的簡單模型的集合時,我們可以使用其各自預測分佈的算術或幾何平均值作爲軟目標。當軟目標具有高熵值時,它們爲每個訓練案例提供比硬目標更多的信息,並且在訓練案例之間梯度的變化更小,因此小模型通常可以在比原始繁瑣模型少得多的數據上訓練並可以使用更高的學習率加快訓練過程。

對於像MNIST這樣的任務,繁瑣的模型幾乎總是以非常高的置信度產生正確答案,關於學習函數的許多信息存在於軟目標中非常小的概率的概率值中。例如,一個真實值爲2的一個版本的圖像可以以10-6的概率被識別爲3和以10-9的概率被識別爲7,而對於另一個版本的實際爲2的圖像,它可以是另一種識別結果。這是有價值的信息,它定義了數據上豐富的相似性結構(即它表示哪個2看起來像3,哪個看起來像7),但它在遷移階段對交叉熵損失函數的影響非常小,因爲概率都是如此接近於零。Caruana和他的合作者通過使用logits(最終softmax的輸入)而不是softmax作爲學習小模型的目標產生的概率來解決這個問題,並且他們最小化了由繁瑣模型和小模型產生的logits之間的平方差異。我們更通用的解決方案,稱爲“蒸餾”,是提高最終softmax的溫度,直到繁瑣的模型產生適當柔軟的目標。然後,我們在訓練小模型時使用相同的高溫來匹配這些軟目標。後面將表明,匹配繁瑣模型的logits實際上是一個特殊的蒸餾情況。

用於訓練小模型的遷移集可以完全由未標記的數據組成[1],或者是原始的訓練集。我們發現使用原始訓練集效果很好,特別是如果我們在目標函數中添加一個小項,鼓勵小模型預測真實目標同時也匹配繁瑣模型提供的軟目標。通常,小模型不能與軟目標完全匹配,並且在正確答案的方向上犯錯是有幫助的。

2 蒸餾

神經網絡通常通過使用“softmax”輸出層來產生類概率,該輸出層通過將ziz_i與其他logit進行比較,將針對每個類計算的logit ziz_i轉換爲概率qiq_i
qi=exp(zi/T)jexp(zj/T)q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}

其中T是通常設置爲1的溫度。使用較高的T值可以在類上產生更軟的概率分佈。

在最簡單的蒸餾形式中,通過在遷移集上訓練並將遷移集中的每個類別使用軟目標分佈表示,將知識轉移到蒸餾的小模型,軟目標分佈通過使用在softmax中具有高溫的繁瑣模型產生。訓練蒸餾模型時使用相同的高溫,但訓練完成後,溫度設置爲1

當已知所有或部分遷移集的正確標籤時,通過訓練蒸餾模型以產生正確的標籤,可以顯著改善該方法。一種方法是使用正確的標籤來修改軟目標,但我們發現更好的方法是簡單地使用兩個不同目標函數的加權平均值。第一個目標函數是具有軟目標的交叉熵,該交叉熵是從蒸餾模型的具有高溫的softmax函數中獲得的,該高溫值與繁重模型生成軟目標交叉熵時的溫度值相同。第二個目標函數是具有正確標籤的交叉熵。這是使用與蒸餾模型的softmax中完全相同的logits計算的,但溫度值爲1。我們發現通常通過在第二目標函數上使用可忽略不計的較低權重來獲得最佳結果。由於軟目標產生的梯度的大小爲1/T21 / T^2,因此當使用硬目標和軟目標時,將軟目標乘以T2T^2是很重要的。這確保瞭如果在試驗元參數時改變了用於蒸餾的溫度,則硬和軟目標的相對貢獻保持大致不變。

2.1 匹配logits是蒸餾的特例

遷移集中的每個例子相對於蒸餾模型的每個logit - ziz_i貢獻交叉熵梯度dC/dzidC / dz_i。如果繁瑣的模型具有logit viv_i,其產生軟目標概率pip_i並且遷移訓練在溫度爲T下完成,則該梯度由下式給出:

Czi=1T(qipi)=1T(ezi/Tjezj/Tevi/Tjevi/T)\frac{\partial C}{\partial z_i} = \frac{1}{T}(q_i - p_i)=\frac{1}{T}(\frac{e^{z_i/T}}{\sum_je^{z_j/T}}-\frac{e^{v_i/T}}{\sum_j e^{v_i/T}})

如果溫度高於logits的幅度,我們可以近似:

Czi1T(1+zi/TN+jzj/T1+vi/TN+jvi/T)\frac{\partial C}{\partial z_i} \approx \frac{1}{T}(\frac{1+z_i/T}{N + \sum_j{z_j/T}}-\frac{{1+v_i/T}}{N+\sum_j {v_i/T}})

如果我們現在假設對於每個遷移樣本,logits均值爲零,那麼jzj=jvj\sum_j z_j = \sum_j v_j = 0。上式可以簡化爲:
Czi1NT2(zivi)\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}(z_i - v_i)

因此,在高溫極限下,蒸餾相當於最小化1/2(zivi)21/2(z_i-v_i)^2,條件是對於每個遷移樣本提供零均值logits。在較低的溫度下,蒸餾對匹配比平均值更負的標記的關注要少得多。這是有利的,因爲這些對數幾乎完全不受用於訓練繁重模型的代價函數的限制,因此它們可能非常嘈雜。另一方面,非常負的logits可以傳達關於由繁重模型獲取的知識的有用信息。這些影響中的哪一個主導是一個經驗問題。我們發現,當蒸餾模型太小而無法捕獲繁瑣模型中的所有知識時,中間溫度效果最好,這強烈暗示忽略大的負logit可能會有所幫助。

3 在MNIST上的初步試驗

爲了瞭解蒸餾的工作情況,我們在所有60,000個訓練案例中訓練了一個單個大型神經網絡,其中有兩個隱藏的1200個整流線性隱藏單元層。如[5]中所述,使用dropout和權重約束強烈地對網絡進行正則化。dropout可以被視爲一種訓練共享權重的指數個大型模型集合的方法。另外,輸入的圖像在任何方向上最多有兩個像素抖動,該網絡具有67個測試誤差。而較小的具有兩層各800個整流線性隱藏單元的網絡,沒有使用正則化則實現了146個錯誤。但是,如果較小的網絡僅僅通過在20的溫度下增加匹配大網絡產生的軟目標的附加任務來進行正則化,則實現了74個測試誤差。這表明軟目標可以將大量知識傳遞給蒸餾模型,包括從遷移訓練數據中學習如何泛化的知識,即使遷移集不包含任何變換。

當蒸餾網絡在其兩個隱藏層中的每一箇中具有300個或更多個單元時,所有高於8的溫度給出相當類似的結果。但是當每層急劇減少到30個單位時,2.5到4範圍內的溫度比高溫或更低溫度的溫度要好得多。

然後我們嘗試刪除轉移集中的數字3的所有示例。因此,從蒸餾模型的角度來看,3是一個從未見過的神話數字。儘管如此,蒸餾模型僅產生206個測試錯誤,其中133個是1010個測試樣本中的3。大多數錯誤是由於對3的學習偏差太低而引起的。如果這個偏差增加3.5(這可以優化測試集的整體性能),蒸餾模型會產生109個錯誤,其中14個是3。因此,在正確的偏置下,儘管在訓練期間從未見過3,但蒸餾模型仍能獲得98.6%的針對3的測試準確率。如果轉移集僅包含訓練集中的7和8,則蒸餾模型的測試誤差爲47.3%,但當7和8的偏差減少7.6以優化測試性能時,測試誤差降至13.2%。

5 在非常大的數據集上訓練專家模型集

訓練一個模型集合是利用並行計算的一種非常簡單的方法,通常的反對意見是模型集合在測試時需要太多的計算,這個問題可以通過蒸餾來處理。對於模型集成還有另一個重要的反對意見:如果單個模型是大型神經網絡,並且數據集非常大,那麼即使很容易進行並行化處理,在訓練時所需的計算量也太多了。

在這一節中,我們給出了這樣一個數據集的例子,我們展示了學習專家模型,每個模型都集中在不同的混淆的類子集上,從而減少學習模型集合所需的總計算量。專注於對細微差別進行判斷的專家模型的主要問題是他們很容易過擬合,我們描述瞭如何通過使用軟目標來防止這種過擬合。

5.1 JFT數據集

JFT是一個谷歌內部的數據集,它有1億張帶有15000個標籤的圖像。當我們做本文研究時,谷歌的JFT基線模型是一個深度卷積神經網絡[7],它使用異步隨機梯度下降在大量核心上進行了大約6個月的訓練。訓練使用了兩種類型的並行化。首先,使用了數據並行,神經網絡在不同的核心上運行,並處理訓練集中不同的小批量數據。每個模型副本計算其當前小批量數據的平均梯度,並將此梯度發送到共享的參數服務器,該服務器將返回參數的新值。這些新值反映了參數服務器自上次向所有模型副本發送參數以來接收到的所有梯度。第二,使用了模型並行,每個模型副本通過在每個核心上放置不同的神經元子集而分佈在多個核心上。集成學習是第三種並行,但需要大量的計算核心。等待數年來訓練一組模型不是一個良好的選擇,因此我們需要一種更快的方法來改進基線模型。

5.2 專家模型

當類的數量非常大時,繁瑣的模型是一個集合,其中包含一個在所有數據上訓練的通用模型和許多“專家”模型,每個專家模型都在從一個非常容易混淆的子類(如不同類型的蘑菇)上進行訓練。通過將所有它不關心的類別組合成一個垃圾箱類別,可以使這類專家的Softmax更小。

爲了減少過擬合,共享學習低級特徵檢測器的工作,每個專家模型都是用通用模型的權重初始化的。然後通過訓練專家模型對這些權重進行輕微修改,訓練集一半來自於其特殊子集,另一半則通過隨機從訓練集的其餘部分進行抽樣獲得。訓練結束後,我們可以通過將垃圾箱類的logit增加到專家類被過採樣的概率的對數,從而糾正偏差訓練集。

5.3 將類賦予專家模型

爲了獲得專家對象類別的分組,我們決定關注我們的完整網絡經常混淆的類別。儘管我們可以計算混淆矩陣並將其用作查找此類聚類的方法,但我們選擇了一種更簡單的方法,它不需要真正的標籤來構建聚類。

特別是,我們將聚類算法應用於我們通用模型預測的協方差矩陣,因此,通常一起預測的一組類SmS^m將被用作我們的第m個專家模型的目標。我們將K-means算法的在線版本應用於協方差矩陣的列,並獲得了合理的聚類(如表2所示)。我們嘗試了幾種產生不同的聚類算法,都產生了類似的結果。

5.4 使用專家模型的集成進行推理

在調查專家模型被蒸餾後會發生什麼之前,我們想看看專家模型的集成表現如何。除了專家模型,我們總是有一個通用模型,以便我們可以處理我們沒有專家模型的類別,進而決定使用哪些專家模型。給定輸入圖像x,我們分兩步進行第一類的分類:

  • 第一步,對每一個測試樣本,我們根據通用模型找出前n個最可能的類別,稱之爲類別k。在我們的實驗中,我們設定n=1;
  • 第二步,我們實驗所有的專家模型m,針對混淆類SmS^m,其與k有一個非空交集,並將其稱爲活動的專家組AkA_k(請注意,此集合可能爲空)。然後,我們在所有的類上找到完整的概率分佈q,即最小化:
    KL(pg,q)+mAkKL(pm,q)KL(p^g,q)+\sum_{m \in A_k}KL(p^m,q)
    KL表示KL散度,pmpgp^mp^g表示專家模型和通用模型產生的概率分佈。分佈pmp^m是m的所有專家類加上一個垃圾箱類的分佈,因此當從完全q分佈計算其KL散度時,我們將完整q分佈分配給m的垃圾箱中的所有類的所有概率相加。

5.5 結論

在這裏插入圖片描述
從訓練好的基線全網絡開始,專家模型訓練速度非常快(JFT的幾天而不是幾周)。此外,所有專家模型都完全獨立地被訓練。表3顯示了基線系統和基線系統與專家模型相結合的絕對測試精度。使用61個專業模型,整體測試精度相對提高4.4%。我們還報告了條件測試準確性,這是僅考慮屬於專家類的示例的準確性,並將我們的預測限制在該類子集中。

對於我們的JFT專家實驗,我們訓練了61個專家模型,每個模型有300個類別(加上垃圾箱課程)。因爲專家模型的類集是相交的,所以我們經常有多個專家模型來覆蓋特定的圖像類。表4顯示了測試集示例的數量,使用專家模型時在位置1處正確的示例數量的變化,以及按專家數量細分的JFT數據集的top1準確度的相對百分比改進。當我們有更多的專家來覆蓋特定的類別時,我們對總體趨勢感到鼓舞,因爲訓練獨立的專家模型非常容易並行化,因此準確性的提高會更大。

在這裏插入圖片描述

6 使用軟目標進行正則化

我們關於使用軟目標而不是硬目標的主要主張之一是,可以在軟目標中攜帶許多有用的信息,這些信息不可能用單個硬目標編碼。在本節中,我們通過使用少得多的數據來擬合前面描述的基線語音模型的85M參數來證明這是一個非常大的影響。表5顯示只有3%的數據(大約20M的例子),用硬目標訓練基線模型會導致嚴重的過度擬合(我們進行了提前停止,因爲準確度在達到44.5%後急劇下降),而使用軟目標訓練的同一模型能夠恢復完整訓練集中的幾乎所有信息(約2%下降)。更值得注意的是,我們不必提前停止:具有軟目標的系統簡單地“收斂”到57%。這表明軟目標是將由所有數據訓練的模型發現的規律傳達給另一個模型的非常有效的方式。

6.1 使用軟目標防止專家模型過擬合

我們在JFT數據集實驗中使用的專家模型將所有非專家類摺疊成一個垃圾箱類。如果我們允許專家系統在所有類別上擁有完整的softmax,那麼可能有一種更好的方法來防止它們過度擬合而不是使用早期停止。專家模型接受有關其特殊類別高度豐富的數據的訓練。這意味着它的訓練集的有效大小要小得多,並且很容易過度擬合它對應的特殊類別。這個問題不能通過讓專家變得更小來解決,因爲我們失去了從建模所有非專家類別中獲得的非常有用的轉移效果。

我們使用3%的語音數據的實驗強有力地表明,如果專家模型用通用的權重進行初始化,我們可以通過使用非特殊類的軟目標訓練它來保留關於非特殊類的幾乎所有知識而不是用硬目標訓練它之外的特殊類別。軟目標可以由通用模型提供。我們目前正在探索這種方法。

7 與專家混合的關係

使用經過數據子集訓練的專家與專家混合[6]有一些相似之處,後者使用門控網絡來計算爲每個專家分配每個例子的概率。在專家模型學習處理分配給他們的示例的同時,門控網絡正在學習根據該示例的專家的相對判別性能來選擇將每個示例分配給哪些專家。使用專家的判別性能來確定學習的分配比簡單地聚類輸入向量併爲每個羣集分配專家要好得多,但這使得訓練難以並行化:首先,每個專家的加權訓練集的變化依賴於所有其他專家;其次,門控網絡需要比較同一個例子中不同專家的表現,以瞭解如何修改其分配概率。這些困難意味着專家的混合很少用於可能最有益的制度:具有包含明顯不同子集的大型數據集的任務。

並行化多個專家的訓練要容易得多。我們首先訓練一個通用模型,然後使用混淆矩陣來定義專家訓練的子集。一旦定義了這些子集,就可以完全獨立地訓練專家模型。在測試時,我們可以使用通用模型中的預測來確定哪些專家是相關的,只需要運行這些專家。

8 討論

我們已經證明,蒸餾可以很好地將知識從集合或大型高度正規化模型轉移到較小的蒸餾模型中。在MNIST蒸餾工作非常好,即使用於訓練蒸餾模型的轉移裝置缺少一個或多個類的任何實例。對於Android聲音搜索使用的深度聲學模型,我們已經證明通過訓練深度神經網絡集合所實現的幾乎所有改進都可以被提煉成相同大小的單個神經網絡。該模型部署起來要容易得多。

對於非常大的神經網絡來說,即使是訓練一個完整的整體也是不可行的,但我們已經證明,通過學習大量的數據,可以顯著提高已經訓練了很長時間的單個真正大網絡的性能。專家網絡,每個網都學會在高度混亂的集羣中區分各類。我們還沒有表明我們可以將專家的知識提煉回單個大網絡。

以下內容摘抄於:https://blog.csdn.net/shuzfan/article/details/53839372

—————————— 背景介紹 ——————————

大家都想要得到一個又好又快的模型,但是實際情況往往是模型越小則性能越差。文獻[1]中提出了一種策略:大模型學習到的知識可以通過“提取”的方法轉移到一個小模型上

所以,本文的宏觀策略就是先訓練一個性能很好的大模型,然後再據此學習得到一個性能不差的小模型。

—————————— 方法介紹 ——————————

因爲整體方法比較簡單,涉及到的推導和證明也比較少,所以這裏直接給出方法步驟:

(1)訓練一個大模型

使用普通的SoftmaxWithLoss訓練一個性能好但比較複雜的網絡,記作A
(PS. 該方法只適用於softmax分類)

(2)獲取soft target label

將A中的softmax替換爲下面的新的soft_softmax, 然後將訓練集跑一遍,並將每張訓練圖片的網絡輸出結果保存下來。

float

(1)式中T=1時,退化成傳統的softmax,T無窮大時,結果趨近於1/C,即所有類別上的概率趨近於相等。T>1時,我們就能獲得soft target label。

這裏解釋兩點:

——爲什麼稱之爲soft target label?

舉一個栗子,假如我們分三類,然後網絡最後的輸出是[1.0 2.0 3.0],我們可以很容易的計算出,傳統的softmax(即T=1)對此進行處理後得到的概率爲[0.09 0.24 0.67],而當T=4的時候,得到的概率則爲[0.25 0.33 0.42]。
可以看出,當T變大的時候輸出的概率分佈變得平緩了,這就稱之爲soft。

PS.不知道怎麼算概率的,可以看下我的另外一篇博客caffe層解讀系列-softmax_loss

——爲什麼要使用 soft target label?

通常我們使用softmax進行分類的時候,我們的label都是one shot label,比如我們分三類:貓、虎和豬,那麼一張貓的圖片它的label就是[1 0 0]。這種標註方式意味着每一類之間都是獨立的,完全沒有任何聯繫。但是事實上,貓和虎的相似度應該高於貓和豬的相似度,這種豐富的結構信息,one shot label(hard target)描述不了。此外,正因爲one shot label的hard,導致了我們學習得到的概率分佈也相對hard。顯然對於一張貓的圖片強行學習一個[1 0 0]的分佈,其難度要比學一個[0.65 0.3 0.05]的分佈大得多。

(3)訓練一個小網絡

重新創建一個小的網絡,該網絡最後有兩個loss,一個是hard loss,即傳統的softmaxloss,使用one shot label;另外一個是soft loss,即T>1的softmaxloss,使用我們第二步保存下來的soft target label。

整體的loss如下式:

—————————— 參考文獻 ——————————

《Buciluǎ C, Caruana R, Niculescu-Mizil A. Model compression[C]//Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2006: 535-541.》https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章