GradNorm:Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks,梯度歸一化

文章目錄

  gradnorm論文地址:https://arxiv.org/abs/1711.02257

  gradnorm是一種優化方法,在多任務學習(Multi-Task Learning)中,解決 1. 不同任務loss梯度的量級(magnitude)不同,造成有的task在梯度反向傳播中占主導地位,模型過分學習該任務而忽視其它任務;2. 不同任務收斂速度不一致;這兩個問題。

  從實現上來看,gradnorm除了利用label loss更新神經網絡的參數外,還會使用grad loss更新每個任務(task)的損失(loss)在總損失中的權重ww

引言

  以簡單的多任務學習模型shared bottom爲例,兩個任務的shared bottom結構如下,輸出的兩個tower分別擬合兩個任務。

  針對這樣的模型,最簡單的方法就是每個任務單獨計算損失,然後彙總起來,最終的損失函數如下:

loss(t)=lossA(t)+lossB(t)loss(t) = loss_{A}(t)+loss_{B}(t)

  但是,兩個任務的loss反向傳播的梯度量級可能不同,在反向傳播到shared bottom部分時,梯度量級小的任務對模型參數更新的比重少,使得shared bottom對該任務的學習不充分。因此,我們可以簡單的引入權重,平衡梯度,如下:

loss(t)=wA×lossA(t)+wB×lossB(t)loss(t) =w_{A}\times loss_{A}(t)+w_{B}\times loss_{B}(t)

  這樣做並沒有很好的解決問題,首先,如果loss權重ww在訓練過程中爲定值,最初梯度量級大的任務,我們給一個小的ww,到訓練結束,這個小的ww會一直限制這一任務,使得這一任務不能得到很好的學習。因此,需要梯度也是不斷變化的,更新公式如下:

loss(t)=wA(t)×lossA(t)+wB(t)×lossB(t)loss(t) =w_{A}(t)\times loss_{A}(t)+w_{B}(t)\times loss_{B}(t)

  gradnorm就是用梯度,來動態調整loss的ww的優化方法。

gradnorm

  想要動態更新loss的ww,最直觀的方法就是利用grad,因爲在多任務學習中,我們解決的就是多任務梯度不平衡的問題,如果我們能知道ww的更新梯度(這裏的梯度不是神經網絡參數的梯度,是loss權重ww的梯度),就可以利用梯度更新公式,來動態更新ww,就像更新神經網絡的參數一樣,如下,其中λ\lambda沿用全局的神經網絡學習率。

w(t+1)=w(t)+λβ(t)w(t+1) = w(t)+\lambda\beta (t)

  我們的目的是平衡梯度,所以β\beta最好是梯度關於ww的倒數,爲此定義梯度損失如下:

Grad Loss=ΣiGWi(t)GW(t)×[ri(t)]αGrad~Loss = \Sigma_{i}\Big|G_W^{i}(t)-\overline{G}_{W}(t)\times [r_i(t)]^{\alpha}\Big|

GWi(t)=Wwi(t)Li(t)2G_W^{i}(t)=||\bigtriangledown_Ww_i(t)L_i(t)||_2

GW(t)=Etask[GWi(t)]\overline{G}_W(t)=E_{task}[G_W^i(t)]

ri(t)=L~i(t)Etask[L~i(t)]r_i(t)=\frac{\widetilde{L}_{i}(t)}{E_{task}[\widetilde{L}_{i}(t)]}

L~i(t)=Li(t)L0(t)\widetilde{L}_{i}(t)=\frac{L_{i}(t)}{L_{0}(t)}

  這幾個公式就是論文最核心的部分,其中,Grad LossGrad~Loss定義爲,各個任務實際的梯度範數與理想的梯度範數的差的絕對值和;GWi(t)G_W^{i}(t)爲實際的梯度範數,GW(t)×[ri(t)]α\overline{G}_{W}(t)\times [r_i(t)]^{\alpha}爲理想的梯度範數;GWi(t)G_W^{i}(t)是任務ii的帶權損失wi(t)Li(t)w_i(t)L_i(t),對需要更新的神經網絡參數WWWW表示神經網絡參數,ww表示loss權重)的梯度的L2範數;GW(t)\overline{G}_W(t)是對所有任務求得的GWi(t)G_W^{i}(t)的平均;L~i(t)\widetilde{L}_{i}(t)表示任務ii的反向訓練速度,L~i(t)\widetilde{L}_{i}(t)越大,Li(t)L_{i}(t)越大,任務ii訓練越慢ri(t)r_i(t)是任務ii的相對反向訓練速度。

  α\alpha是超參數,α\alpha越大,對訓練速度的平衡限制越強。爲了節約計算時間,Grad LossGrad~Loss僅對shared bottom的輸出部分計算。

  有了Grad LossGrad~Loss,就可以利用Grad LossGrad~Losswi(t)w_i(t)求導,得到上面梯度更新公式中需要的β(t)\beta(t)。爲了防止wi(t)w_i(t)變爲0,在對Grad LossGrad~Loss求導時,認爲GW(t)×[ri(t)]α\overline{G}_{W}(t)\times [r_i(t)]^{\alpha}部分爲常數,即使其中有wi(t)w_i(t)。在每一個batch step的最後,爲了節藕gradnorm過程中,利用Grad LossGrad~Losswi(t)w_i(t)求導過程與全局訓練神經網絡的學習率的關係,會對wi(t)w_i(t)在進行Σiwi(t)=T\Sigma_{i}w_i(t)=T的renormalize,TT是任務總數。

  gradnorm示意如下:

  

  gradnorm在單個batch step的流程總結如下:

1.前向傳播計算總損失Loss=ΣiwiliLoss=\Sigma_iw_il_i;
2.計算GWi(t)G_W^{i}(t)ri(t)r_i(t)GWi(t)\overline{G}_W^{i}(t)
3.計算Grad LossGrad~Loss
4.計算Grad LossGrad~Losswiw_i的導數;
5.利用第1步計算的的LossLoss反向傳播更新神經網絡參數;
6.利用第4步的導數更新wiw_i(更新後在下一個batch step生效);
7.對wiw_i進行renormalize(下一個batch step使用的是renormalize之後的wiw_i)。

  附上論文原版步驟:

  

參考文獻:
https://github.com/brianlan/pytorch-grad-norm

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章