變分互信息蒸餾(Variational mutual information KD)

原文標題是Variational Information Distillation for Knowledge Transfer,是CVPR2019的錄用paper。

VID方法

在這裏插入圖片描述
思路比較簡單,就是利用互信息(mutual information,MI)的角度,增加teacher網絡與student網絡中間層特徵的MI,motivation是因爲MI可以表示兩個變量的依賴程度,MI越大,表明兩者的輸出越相關。
首先定義輸入數據xp(x)\bm{x}\sim p(\bm{x}),給定一個樣本x\bm{x},得到關於teacher和student輸出的KK個對集合R={(t(k),s(k))}k=1K\mathcal{R}=\{(\bm{t}^{(k)},\bm{s}^{(k)})\}_{k=1}^{K},KK表示選擇的層數。變量對的MI被定義爲I(t;s)=H(t)H(ts)=Et[logp(t)]+Et,s[logp(ts)]I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =-\mathbb{E}_{\bm{t}}[\log p(\bm{t})]+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]
之後可以設計如下的loss函數來增大teacher和student之間的輸出特徵的互信息:
L=LSk=1KλkI(t(k),s(k))\mathcal{L}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}I(\bm{t}^{(k)},\bm{s}^{(k)})
其中LS\mathcal{L_{S}}表示task-specific的誤差,λk\lambda_{k}是超參數用於平衡誤差。因爲精確的計算MI是困難的,這裏採用了變分下界(variational lower bound)的trick,採用variational的思想使用一個variational分佈q(ts)q(\bm{t}|\bm{s})去近似真實分佈p(ts)p(\bm{t}|\bm{s})
Note that variational的思想就是針對某個分佈很難求解的時候,採用另外一個分佈來近似這個分佈的做法,並使用變分信息最大化 (論文:The IM algorithm: A variational approach to information maximization) 的方法求解變分下界(variational low bound),這方法也被用在InfoGAN中。
I(t;s)=H(t)H(ts)=H(t)+Et,s[logp(ts)]=H(t)+Et,s[logq(ts)]+Es[DKL(p(ts)q(ts))]H(t)+Et,s[logq(ts)]I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))]\\ \geq H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]
Et,s[logp(ts)]=Et,s[logq(ts)]+Es[DKL(p(ts)q(ts))]\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]=\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))]這個關係是由變分信息最大化中得到的,真實分佈logp(ts)\log p(\bm{t|s})的期望等於變分分佈Et,s[logq(ts)]\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]的期望+兩分佈的KL散度期望。因爲KL散度的值是恆大於0的,所以得到變分下界。進一步可以得到如下的誤差函數:
L~=LSk=1KλkEt(k),s(k)[logq(t(k)s(k))]\mathcal{\tilde{L}}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}\mathbb{E}_{\bm{t^{(k)},s^{(k)}}}[\log q(\bm{t^{(k)}|s^{(k)}})]
H(t)H(\bm{t})由於和待優化的student參數無關,所以是常數。聯合的訓練學生網絡利用target task和最大化條件似然去擬合teacher激活值。

作者採用高斯分佈來實例化變分分佈,這裏的採用heteroscedastic的均值μ()\bm{\mu}(\cdot),即μ()\bm{\mu}(\cdot)是關於student輸出的函數;同時採用homoscedastic的方差σ\bm{\sigma},即不是關於student輸出的函數,作者嘗試採用heteroscedastic的均值σ()\bm{\sigma}(\cdot),但是容易訓練不穩定且提升不大。μ()\bm{\mu}(\cdot)其實就是相當於在feature KD時teacher與student之間的迴歸器,包含卷積等操作。
logq(ts)=c=1Ch=1Hw=1Wlogq(tc,h,ws)=c=1Ch=1Hw=1Wlogσc+(tc,h,wμc,h,w(s))22σc2+constant-\log q(\bm{t|s})=-\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log q(t_{c,h,w}|\bm{s})\\ =\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log \sigma_{c}+\frac{(t_{c,h,w}-\mu_{c,h,w}(\bm{s}))^{2}}{2\sigma_{c}^{2}}+\rm{constant}
σc=log(1+exp(αc))\sigma_{c}=\log(1+exp(\alpha_{c}))αc\alpha_{c}是一個可學習的參數。
對於logit層,logq(ts)=n=1Nlogq(tns)=n=1Nlogσn+(tnμn(s))22σn2+constant-\log q(\bm{t|s})=-\sum_{n=1}^{N}\log q(t_{n}|\bm{s})\\ =\sum_{n=1}^{N}\log \sigma_{n}+\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2\sigma_{n}^{2}}+\rm{constant}
這裏μ()\bm{\mu}(\cdot)是一個線性的變換矩陣。

與MSE的區別

作者認爲當前基於MSE的方法是該方法在方差相同時的特例,即爲:
logq(ts)=n=1N(tnμn(s))22+constant-\log q(\bm{t|s})=\sum_{n=1}^{N}\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2}+\rm{constant}
VID比MSE的好處爲建模了不同維度的方差,使得更加靈活的方式來避免一些model capacity用來到一些無用的信息。MSE採用一樣的方差會高度限制student,如果teacher的無用信息也同樣的地位擬合,會造成過擬合問題,浪費掉了student的網絡capacity。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章