ICCV 2019年總共有三篇文章講了模型壓縮與加速,這是其中一篇。文章9月份更新了最新版,網上關於文章的解讀也有了不少,我在此篇博客中簡單講下自己的理解。
1 文章解決的問題
這篇文章解決了下圖架構的一類問題,即:
- 已經有了大型複雜的網絡處理分類問題;
- 沒有訓練數據;
- 想要獲得小型的、能夠部署、性能接近大型網絡的分類模型。
2 如何做的
2.1 預備知識
想要看懂這篇文章,首先得了解知識蒸餾的內容。知識蒸餾問題的架構如下圖所示。與關於知識蒸餾的解讀可以查閱我的上一篇博文知識蒸餾(Distilling Knowledge )的核心思想 。下圖與上圖的區別在於數據與老師網絡是否已知。在數據未知的情況下,作者採用 GAN 網絡,利用噪聲圖像數據,生成了訓練數據。從而使得學生網絡基於訓練數據不斷優化自身參數,學習老師網絡的知識。
2.2 核心點
2.2.1 整體架構
文章的核心點在於如何生成用於訓練的數據。GAN 理論的核心點在於納什均衡理論。訓練 GAN 網絡需要真實數據和標籤。老師網絡作爲 GAN 網絡中的判別器,無須進行訓練,所以只要訓練一個生成器即可。於是作者設計了一套損失函數,幫助均勻地生成各個類的訓練數據,爲什麼說均勻的呢,下面會進一步介紹。
整體的網絡架構如下圖所示。
2.2.2 損失函數
訓練生成器的損失函數:
爲了訓練生成器,作者設計了一個整合三種損失函數的loss。
噪聲 z 經過生成器生成圖像 x,經過 TN 網絡得到輸出向量 y,利用 arg max 得到標籤 t (這裏 t 的定義在於 y 的輸出值,比方 y 是 [0, 0, 1](總有一個最大的值),那麼該數據的標籤就是1(最大的值)對應的那類。
- Loh 是 one-hot 損失函數,採用交叉熵損失函數。通俗的講,就是隨機信號第一次經過 TN 的輸出 y 是[0.2, 0.3, 0.5], 那麼標籤就是第三類,訓練的過程中讓 y 不斷靠近 [0, 0, 1]。
- La 是激活函數,f 是TN 所提取的最終特徵,爲了使得更多的最終特徵成爲能夠模擬近似真實輸入圖片的特徵(不稀疏)來輸出,用了加負號的L1正則進行約束(L1正則本身產生稀疏向量)。
- Lie 是信息熵損失函數,當所有類別均衡時,Hinfo取得最大值,增加負號取到最小值。因此該損失函數爲了獲得均衡分佈的數據。
訓練 SN 的損失函數:
同知識蒸餾中的交叉熵損失函數。
2.2.3 訓練流程
整體訓練流程如下圖所示,簡單講就是先訓練生成器,再訓練 SN。
3 其他
論文中說明了損失函數的可導性、有效性。實驗的結果可查看原文。我比較好奇特徵圖激活損失函數的有效性,論文中表格2內的 TOP 1 accuracy 中,特徵圖激活函數對於準確性的提升十分有限。畢竟最後輸出的 TN 輸出的 y 是稀疏的,那麼 f 非稀疏的必要性,可能是有待探討的。