作者: Liyang
Data-Free Learning of Student Networks
論文地址:https://arxiv.org/pdf/1904.01186
開源地址:https://github.com/huawei-noah/DAFL
論文作者:Hanting Chen, Yunhe Wang, Chang Xu, Zhaohui Yang, Chuanjian Liu, Boxin Shi;Chunjing Xu, Chao Xu, Qi Tian(北京大學電子與計算機工程學院,華爲諾亞方舟實驗室等)
前言
本文將對ICCV2019會議論文《Data-Free Learning of Student Networks》進行解讀,這篇論文在神經網絡壓縮領域有相當高的實用價值。作者從難以獲取teacher網絡原始訓練集的角度出發,提出了一種將teacher網絡用作固定的判決器,利用GAN(Generative Adversarial Networks)的生成器來產生模擬原始訓練集的訓練樣本,進一步訓練、獲得具有較小尺寸和複雜度的student(portable)網絡。
實驗結果表明,作者所提出的DAFL(Data-Free Learning)方法在MNIST、CIFAR、CelebA等數據集上具有很好的性能,相對於KD(Knowledge Distillation)等方法具有更好的實用性。
背景
神經網絡壓縮算法目前根據有無原始數據的參與分爲兩種。
Data-Driven類
Hinton等提出了一種知識蒸餾方法(knowledge distillation,KD),該方法提煉出經過預訓練的teacher網絡的信息,以學習portable (student )網絡【Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distillingthe knowledge in a neural network. arXiv preprintarXiv:1503.02531, 2015】。Denton等利用低秩分解(SVD)來處理全連接層的權重矩陣【Emily L Denton, Wojciech Zaremba, Joan Bruna, Yann Le- Cun, and Rob Fergus. Exploiting linear structure within convolutional networks for efficient evaluation. In NIPS, 2014】。Han等採用修剪、量化和霍夫曼編碼來獲得緊湊的深度CNN,使之具有較低的計算複雜度【Song Han, Huizi Mao, and William J Dally. Deep compression:Compressing deep neural networks with pruning, trained quantization and huffman coding. arXiv preprint arXiv:1510.00149, 2015】。Li等進一步提出了一種特徵模擬框架,以訓練有效的卷積網絡進行目標檢測【Quanquan Li, Shengying Jin, and Junjie Yan. Mimicking very efficient network for object detection. In CVPR, pages 7341-7349. IEEE, 2017】。
上述方法在大多數據集上取得了良好的效果,但如果沒有原始訓練數據集,則很難應用。
Data- Free類
Lopes等利用原始訓練數據集記錄的 “元數據”(meta-data)(例如,每層激活的平均值和標準偏差),但大多數訓練過的CNN很難提供此數據【Raphael Gontijo Lopes, Stefano Fenu, and Thad Starner.Data-free knowledge distillation for deep neural networks.arXiv preprint arXiv:1710.07535, 2017】。Srinivas和Babu提出在完全連接的層中直接合並相似的神經元來壓縮網絡,但這很難應用於未詳細說明結構和參數信息的卷積層和網絡【Suraj Srinivas and R Venkatesh Babu. Data-free parameter pruning for deep neural networks. arXiv preprint arXiv:1507.06149, 2015】。
實際上,由於如涉及隱私、傳輸限制等因素,原始訓練數據集和詳細的網絡結構、參數等很難獲取,這就意味着上述兩類方法難於應用。
整體架構
GAN有一個生成器和一個判決器,給定的teacher網絡同時作爲GAN的判決器,不對其進行任何更新。Random Signals(隨機信號)輸入到GAN的生成器,變換爲模擬的原始數據,由判決器進行識別。生成器生成一組數據後,再通過KD方法對student網絡的參數進行更新。
Data-free Student Network learning
Teacher-Student關係
由於很難獲取原始訓練數據集,有時也無法獲得參數和詳細結構信息。作者從teacher-student學習範例來着手。作者認爲KD並未利用給定網絡的參數和體系結構的信息,儘管可能僅提供有限的接口(如輸入和輸出接口),但仍然可以從teacher網絡繼承一些有用的信息。令NT和NS表示teacher網絡和所需的portable/student網絡,可用以下基於知識提煉的損失函數來優化student網絡:
其中Hcross爲交叉熵損失,和分別是NT和NS的輸出。利用知識轉移技術,可在沒有給定網絡特定架構的情況下優化portable網絡。
用GAN生成訓練樣本
- one-hot損失函數
GAN由生成器G和判決器D組成,G用來生成數據,D用來識別真實圖像與G生成的圖像之間的差異。若獲得最優的G,必須先得到一個最佳判決器D*,而判決器一般需要真實的圖像進行訓練。作者認爲給定的深度神經網絡已經在大規模數據集上得到了很好的訓練,因此它也可以從圖像中提取語義特徵,因此不需額外的判決器,所以作者將之視爲固定的判別器。
另外,判決器的輸出是顯示輸入圖像在vanilla GANs中是真實還是僞造的概率。但是,考慮到將teacher網絡作爲判別器,其輸出是圖像分類。爲此作者設計了一些新的損失函數。
如給定一組隨機向量{},通過公式,生成圖像{}。將這些圖片輸入到teacher網絡,由公式,獲得輸出{ }。再根據公式,可獲得預測的標籤{}。如果G生成的圖像遵循與原始訓練數據相同的分佈,則它們應具有相似的輸出。因此,作者引入了one-hot損失,來使生成圖片的輸出接近teacher網絡one-hot向量。把{}作爲ground-truth標籤,作者設計的one-hot損失函數爲
其中,N_cross爲交叉熵損失函數。
- 特徵圖激活損失
作者認爲由卷積層提取的中間特徵也是輸入圖像的重要信息。將teacher網絡提取的xi的特徵表示爲f_Ti,它對應於全連接層之前的輸出。由於已對teacher中的過濾器進行了訓練以提取訓練數據中的固有模式,如果輸入圖像是真實的而不是一些隨機矢量,則特徵圖往往會收到更高的激活值。因此,作者將激活損失函數定義爲:
其中,‖·‖_1是常規的l1範數。換言之,如輸入圖像觸發了teacher網絡的某些特徵提取器,則意味與真實圖像相似,是真實圖片的概率較大。
- 圖像的信息熵損失函數
作者採用信息熵損失來衡量生成圖像的類平衡。給定一個概率向量 {},測量混亂度信息熵p的計算方式是。的值表示p擁有的信息量,當所有變量等於1時將取最大值。給定一組輸出向量{ },每個類生成的圖像的頻率分佈是1/n ∑_iy_T^i。因此,作者將生成圖像的信息熵損失定義爲
當損失最小時,1/n ∑_iy_S^i 向量中每個元素都等於1/k,意味着G可以以大致相同的概率生成每個類別的圖像。因此,最小化所生成圖像的信息熵可以導致合成圖像的平衡集合。
通過結合上述三個損失函數,得到最終的目標函數
α和β是平衡三個不同項的參數。作者認爲該方法可以直接模擬訓練數據的分佈,更加靈活,高效地生成新圖像。
優化
作者設計的學習過程可以分爲兩個訓練階段。首先,將teacher網絡作爲固定判決器。使用上述的LTotal損失函數,優化生成器G。其次,我們利用KD方法將知識直接從teacher網絡轉移到student網絡。使用KD的損失LKD來優化具有較少參數的student網絡。
實驗
MNIST實驗
作者分別使用基於卷積的結構和由全連接組成的網絡。前者,採用LeNet-5爲teacher網絡,使用LeNet-5-Hafl(修改後的版本爲每層通道數的一半)作爲student模型。後者,teacher網絡由兩個有1200個單元的隱藏層組成(Hinton-784-1200-1200-10),student網絡由兩個有800個單元的隱藏層組成(Hinton-784-800-800-10)。表1爲MNIST實驗結果。
作者首先採用隨機生成的正態分佈數據來作爲訓練集時,student網絡準確度僅有88.01%,再用USPS數據集訓練student網絡準確度也僅達到94.56%,說明原始訓練集的效果是其他數據集難以替代。作者提出的DAFL方法達到了98.20%的準確度。
在全連接結構上,與前者結果類似,DAFL方法達到了97.91%的準確度。
消融實驗
作者用隨機生成的樣本,student網絡的準確度只有88.01%。僅用one-hot損失或特徵圖激活損失,所產生的樣本不均衡,會導致較差。將Loh或La與Lie結合時,student網絡分別爲97.25%和95.53%。一起使用時,可達到最佳性能。
可視化結果實驗
說明生成器確實以某種方式在學習原始數據分佈。
第一卷積層中過濾器的可視化結果如下所示,表明student網絡從teacher網絡獲取了某些有價值的知識。
CIFAR數據集實驗
作者進一步使用ResNet-34作爲teacher網絡,ResNet-18作爲student網絡。
在CIFAR-10上,KD的student網絡爲94.34%的準確度,DAFL方法可達到92.22%準確度,進一步說明DAFL方法可更好地模擬原始數據。
CIFAR數據集CelebA
作者將AlexNet網絡設置爲teacher網絡,student網絡設置爲AlexNet-Hafl(修改後的版本爲每層通道數的一半)。
DAFL的student網絡的準確度爲80.03%,與teacher網絡的準確度幾乎相當。
擴展實驗
作者使用生成的圖像訓練與teacher網絡具有相同架構的student網絡。重新訓練的student網絡的準確度與teacher網絡的準確度非常接近。
總結
神經網絡壓縮算法通常需要原始訓練數據。但在實際中,由於某些隱私和傳輸限制,要獲取給定網絡的訓練數據和詳細網絡結構信息並非易事。在本文中,作者提出了一個新穎的框架來訓練生成器,以模擬原始訓練數據集。然後通過KD方法有效地學習、得到portable網絡。在MNIST、CIFAR等數據集上的實驗表明,作者所提出的DAFL方法可獲得性能較好的portable網絡,顯示了一定的實用價值。