原文標題是Variational Information Distillation for Knowledge Transfer,是CVPR2019的錄用paper。
VID方法
思路比較簡單,就是利用互信息(mutual information,MI)的角度,增加teacher網絡與student網絡中間層特徵的MI,motivation是因爲MI可以表示兩個變量的依賴程度,MI越大,表明兩者的輸出越相關。
首先定義輸入數據x∼p(x),給定一個樣本x,得到關於teacher和student輸出的K個對集合R={(t(k),s(k))}k=1K,K表示選擇的層數。變量對的MI被定義爲I(t;s)=H(t)−H(t∣s)=−Et[logp(t)]+Et,s[logp(t∣s)]
之後可以設計如下的loss函數來增大teacher和student之間的輸出特徵的互信息:
L=LS−k=1∑KλkI(t(k),s(k))
其中LS表示task-specific的誤差,λk是超參數用於平衡誤差。因爲精確的計算MI是困難的,這裏採用了變分下界(variational lower bound)的trick,採用variational的思想使用一個variational分佈q(t∣s)去近似真實分佈p(t∣s)。
Note that variational的思想就是針對某個分佈很難求解的時候,採用另外一個分佈來近似這個分佈的做法,並使用變分信息最大化 (論文:The IM algorithm: A variational approach to information maximization) 的方法求解變分下界(variational low bound),這方法也被用在InfoGAN中。
I(t;s)=H(t)−H(t∣s)=H(t)+Et,s[logp(t∣s)]=H(t)+Et,s[logq(t∣s)]+Es[DKL(p(t∣s)∣∣q(t∣s))]≥H(t)+Et,s[logq(t∣s)]
Et,s[logp(t∣s)]=Et,s[logq(t∣s)]+Es[DKL(p(t∣s)∣∣q(t∣s))]這個關係是由變分信息最大化中得到的,真實分佈logp(t∣s)的期望等於變分分佈Et,s[logq(t∣s)]的期望+兩分佈的KL散度期望。因爲KL散度的值是恆大於0的,所以得到變分下界。進一步可以得到如下的誤差函數:
L~=LS−k=1∑KλkEt(k),s(k)[logq(t(k)∣s(k))]
H(t)由於和待優化的student參數無關,所以是常數。聯合的訓練學生網絡利用target task和最大化條件似然去擬合teacher激活值。
作者採用高斯分佈來實例化變分分佈,這裏的採用heteroscedastic的均值μ(⋅),即μ(⋅)是關於student輸出的函數;同時採用homoscedastic的方差σ,即不是關於student輸出的函數,作者嘗試採用heteroscedastic的均值σ(⋅),但是容易訓練不穩定且提升不大。μ(⋅)其實就是相當於在feature KD時teacher與student之間的迴歸器,包含卷積等操作。
−logq(t∣s)=−c=1∑Ch=1∑Hw=1∑Wlogq(tc,h,w∣s)=c=1∑Ch=1∑Hw=1∑Wlogσc+2σc2(tc,h,w−μc,h,w(s))2+constant
由σc=log(1+exp(αc)),αc是一個可學習的參數。
對於logit層,−logq(t∣s)=−n=1∑Nlogq(tn∣s)=n=1∑Nlogσn+2σn2(tn−μn(s))2+constant
這裏μ(⋅)是一個線性的變換矩陣。
與MSE的區別
作者認爲當前基於MSE的方法是該方法在方差相同時的特例,即爲:
−logq(t∣s)=n=1∑N2(tn−μn(s))2+constant
VID比MSE的好處爲建模了不同維度的方差,使得更加靈活的方式來避免一些model capacity用來到一些無用的信息。MSE採用一樣的方差會高度限制student,如果teacher的無用信息也同樣的地位擬合,會造成過擬合問題,浪費掉了student的網絡capacity。