讀論文:Data-Free Learning of Student Networks

    ICCV 2019年總共有三篇文章講了模型壓縮與加速,這是其中一篇。文章9月份更新了最新版,網上關於文章的解讀也有了不少,我在此篇博客中簡單講下自己的理解。

1 文章解決的問題

    這篇文章解決了下圖架構的一類問題,即:

  1. 已經有了大型複雜的網絡處理分類問題
  2. 沒有訓練數據;
  3. 想要獲得小型的、能夠部署、性能接近大型網絡的分類模型。

           

2 如何做的

2.1 預備知識

    想要看懂這篇文章,首先得了解知識蒸餾的內容。知識蒸餾問題的架構如下圖所示。與關於知識蒸餾的解讀可以查閱我的上一篇博文知識蒸餾(Distilling Knowledge )的核心思想 。下圖與上圖的區別在於數據與老師網絡是否已知。在數據未知的情況下,作者採用 GAN 網絡,利用噪聲圖像數據,生成了訓練數據。從而使得學生網絡基於訓練數據不斷優化自身參數,學習老師網絡的知識。

          

2.2 核心點

2.2.1 整體架構

    文章的核心點在於如何生成用於訓練的數據。GAN 理論的核心點在於納什均衡理論。訓練 GAN 網絡需要真實數據和標籤。老師網絡作爲 GAN 網絡中的判別器,無須進行訓練,所以只要訓練一個生成器即可。於是作者設計了一套損失函數,幫助均勻地生成各個類的訓練數據,爲什麼說均勻的呢,下面會進一步介紹。

   整體的網絡架構如下圖所示。

2.2.2 損失函數

    訓練生成器的損失函數:

    爲了訓練生成器,作者設計了一個整合三種損失函數的loss。

    噪聲 z 經過生成器生成圖像 x,經過 TN 網絡得到輸出向量 y,利用 arg max 得到標籤 t (這裏 t 的定義在於 y 的輸出值,比方 y 是 [0, 0, 1](總有一個最大的值),那麼該數據的標籤就是1(最大的值)對應的那類。

  • Loh 是 one-hot 損失函數,採用交叉熵損失函數。通俗的講,就是隨機信號第一次經過 TN 的輸出 y 是[0.2, 0.3, 0.5], 那麼標籤就是第三類,訓練的過程中讓 y 不斷靠近 [0, 0, 1]。
  • La 是激活函數,f 是TN 所提取的最終特徵,爲了使得更多的最終特徵成爲能夠模擬近似真實輸入圖片的特徵(不稀疏)來輸出,用了加負號的L1正則進行約束(L1正則本身產生稀疏向量)。
  • Lie 是信息熵損失函數,當所有類別均衡時,Hinfo取得最大值,增加負號取到最小值。因此該損失函數爲了獲得均衡分佈的數據。

    訓練 SN 的損失函數:

    同知識蒸餾中的交叉熵損失函數。

2.2.3 訓練流程

    整體訓練流程如下圖所示,簡單講就是先訓練生成器,再訓練 SN。

                                             

3 其他

    論文中說明了損失函數的可導性、有效性。實驗的結果可查看原文。我比較好奇特徵圖激活損失函數的有效性,論文中表格2內的 TOP 1 accuracy 中,特徵圖激活函數對於準確性的提升十分有限。畢竟最後輸出的 TN 輸出的 y 是稀疏的,那麼 f 非稀疏的必要性,可能是有待探討的。

 

 
 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章