gradnorm論文地址:https://arxiv.org/abs/1711.02257
gradnorm是一種優化方法,在多任務學習(Multi-Task Learning)中,解決 1. 不同任務loss梯度的量級(magnitude)不同,造成有的task在梯度反向傳播中占主導地位,模型過分學習該任務而忽視其它任務;2. 不同任務收斂速度不一致;這兩個問題。
從實現上來看,gradnorm除了利用label loss更新神經網絡的參數外,還會使用grad loss更新每個任務(task)的損失(loss)在總損失中的權重w。
引言
以簡單的多任務學習模型shared bottom爲例,兩個任務的shared bottom結構如下,輸出的兩個tower分別擬合兩個任務。
針對這樣的模型,最簡單的方法就是每個任務單獨計算損失,然後彙總起來,最終的損失函數如下:
loss(t)=lossA(t)+lossB(t)
但是,兩個任務的loss反向傳播的梯度量級可能不同,在反向傳播到shared bottom部分時,梯度量級小的任務對模型參數更新的比重少,使得shared bottom對該任務的學習不充分。因此,我們可以簡單的引入權重,平衡梯度,如下:
loss(t)=wA×lossA(t)+wB×lossB(t)
這樣做並沒有很好的解決問題,首先,如果loss權重w在訓練過程中爲定值,最初梯度量級大的任務,我們給一個小的w,到訓練結束,這個小的w會一直限制這一任務,使得這一任務不能得到很好的學習。因此,需要梯度也是不斷變化的,更新公式如下:
loss(t)=wA(t)×lossA(t)+wB(t)×lossB(t)
gradnorm就是用梯度,來動態調整loss的w的優化方法。
gradnorm
想要動態更新loss的w,最直觀的方法就是利用grad,因爲在多任務學習中,我們解決的就是多任務梯度不平衡的問題,如果我們能知道w的更新梯度(這裏的梯度不是神經網絡參數的梯度,是loss權重w的梯度),就可以利用梯度更新公式,來動態更新w,就像更新神經網絡的參數一樣,如下,其中λ沿用全局的神經網絡學習率。
w(t+1)=w(t)+λβ(t)
我們的目的是平衡梯度,所以β最好是梯度關於w的倒數,爲此定義梯度損失如下:
Grad Loss=Σi∣∣∣GWi(t)−GW(t)×[ri(t)]α∣∣∣
GWi(t)=∣∣▽Wwi(t)Li(t)∣∣2
GW(t)=Etask[GWi(t)]
ri(t)=Etask[Li(t)]Li(t)
Li(t)=L0(t)Li(t)
這幾個公式就是論文最核心的部分,其中,Grad Loss定義爲,各個任務實際的梯度範數與理想的梯度範數的差的絕對值和;GWi(t)爲實際的梯度範數,GW(t)×[ri(t)]α爲理想的梯度範數;GWi(t)是任務i的帶權損失wi(t)Li(t),對需要更新的神經網絡參數W(W表示神經網絡參數,w表示loss權重)的梯度的L2範數;GW(t)是對所有任務求得的GWi(t)的平均;Li(t)表示任務i的反向訓練速度,Li(t)越大,Li(t)越大,任務i訓練越慢;ri(t)是任務i的相對反向訓練速度。
α是超參數,α越大,對訓練速度的平衡限制越強。爲了節約計算時間,Grad Loss僅對shared bottom的輸出部分計算。
有了Grad Loss,就可以利用Grad Loss對wi(t)求導,得到上面梯度更新公式中需要的β(t)。爲了防止wi(t)變爲0,在對Grad Loss求導時,認爲GW(t)×[ri(t)]α部分爲常數,即使其中有wi(t)。在每一個batch step的最後,爲了節藕gradnorm過程中,利用Grad Loss對wi(t)求導過程與全局訓練神經網絡的學習率的關係,會對wi(t)在進行Σiwi(t)=T的renormalize,T是任務總數。
gradnorm示意如下:
gradnorm在單個batch step的流程總結如下:
1.前向傳播計算總損失Loss=Σiwili;
2.計算GWi(t),ri(t),GWi(t);
3.計算Grad Loss;
4.計算Grad Loss對wi的導數;
5.利用第1步計算的的Loss反向傳播更新神經網絡參數;
6.利用第4步的導數更新wi(更新後在下一個batch step生效);
7.對wi進行renormalize(下一個batch step使用的是renormalize之後的wi)。
附上論文原版步驟:
參考文獻:
https://github.com/brianlan/pytorch-grad-norm