硬核推導Google AdaFactor:一個省顯存的寶藏優化器

一隻小狐狸帶你解鎖煉丹術&NLP祕籍

作者:蘇劍林(來自追一科技,人稱“蘇神”)

前言

自從GPT、BERT等預訓練模型流行起來後,其中一個明顯的趨勢是模型越做越大,因爲更大的模型配合更充分的預訓練通常能更有效地刷榜。不過,理想可以無限遠,現實通常很侷促,有時候模型太大了,大到哪怕你擁有了大顯存的GPU甚至TPU,依然會感到很絕望。比如GPT2最大的版本有15億參數,最大版本的T5模型參數量甚至去到了110億,這等規模的模型,哪怕在TPU集羣上也沒法跑到多大的batch size。

這時候通常要往優化過程着手,比如使用混合精度訓練(tensorflow下還可以使用一種叫做bfloat16的新型浮點格式),即省顯存又加速訓練;又或者使用更省顯存的優化器,比如RMSProp就比Adam更省顯存。本文則介紹AdaFactor,一個由Google提出來的新型優化器,首發論文爲《Adafactor: Adaptive Learning Rates with Sublinear Memory Cost》。

AdaFactor具有自適應學習率的特性,但比RMSProp還要省顯存,並且還針對性地解決了Adam的一些缺陷。

Adam

首先我們來回顧一下常用的Adam優化器的更新過程。設爲迭代步數,爲當前學習率,是損失函數,是待優化參數,則是防止溢出的小正數,那麼Adam的更新過程爲

要省顯存,就首先得知道顯存花在哪裏的。首先,計算量和顯存的大頭肯定都是,也就是說,計算梯度是很費資源的,這也是爲啥“ALBERT相比BERT參數量雖然少了那麼多,但訓練速度也沒見快多少”的原因了;除此之外,顯存的消耗主要是了,我們要維護兩組緩存變量,來滑動計算梯度的前兩階矩(也就是),用以計算參數的更新量。這兩組變量每一組都跟訓練參數本身一樣大,因此對於參數比較多的模型,兩組緩存變量所消耗的顯存也不少。

AdaFactor

在這一節中,我們會相對詳細地介紹一些AdaFactor優化器,介紹中會設計比較多的公式和推導。如果只求一個大致瞭解的讀者,可以自行跳過部分數學內容~

拋棄動量

我們知道,CV模型很多時候要靠“SGD+動量”來煉出最優效果來,自適應學習率優化器通常訓練不出最好的效果。但對於NLP模型來說,情況有點相反,自適應學習率顯得更重要一些,很少聽到由純靠SGD調NLP模型的案例。因此,作爲省顯存的第一步,我們可以拋棄Adam裏邊的動量,這樣就少一組緩存參數了,自然也就省了顯存:

這其實就是RMSProp的變種,比RMSProp多了這一步。

低秩分解

去掉之後,緩存變量直接減少了一半,但AdaFactor還不滿意,它希望保留自適應學習率功能,但把緩存變量的參數量再壓一壓。這一次,它用到了矩陣的低秩分解。

廣義KL散度

在SGD中,所有參數都是共用一個標量學習率;在Adam中,則是每一個參數都有自己的學習率。我們知道通過精調學習率,SGD其實也能有不錯的效果,這表明“每一個參數都有自己的學習率”這件事情都不是特別重要,或者換一種說法,就是“精調每一個參數自己的學習率”並不是特別重要。

這啓發我們,將換一種參數更少的近似可能也就足夠了。而“參數更少的近似”,我們就不難想到低秩分解了。對於的矩陣,我們希望找到的矩陣的矩陣,使得

足夠小時,的參數總量就小於的參數量。爲了“省”到極致,AdaFactor直接讓,即尋找使得

既然要近似,就要有一個度量的標準。很容易想到的標準是歐氏距離,即

但在這個距離之下,並沒有解析解;此外,在優化過程中(即)是非負的,而通過上述目標優化出來的無法保證非負,因此很可能擾亂優化過程。原論文的作者們很機智地換了一個度量標準,使得有解析解。具體來說,它使用了“廣義KL散度”,又稱“I散度”,其形式爲:

這個度量源自不等式,當且僅當時等號成立。所以代入,然後兩端乘,我們有

當且僅當成立,如果有多個分量,那麼對多個分量的結果求和即可,這就得到了度量。顯然,廣義KL散度是概率的KL散度的自然推廣,但它不要求滿足歸一化,只要求它們非負,這正好對應了AdaFactor的場景。而且巧妙的是,這種情形配上這個目標,剛好有解析解:

其實這個解析解也很形象,就是行、列分別求和,然後相乘,再除以全體的和。

推導過程

直接對求偏導數並讓偏導數等於0,得

整理得

注意到如果是一組最優解,那麼也是,說白了,所有的乘以一個常數,所有的也除以這個常數,是不變的。那麼我們就可以隨意指定,因爲它們就只是一個縮放標量而已。不失一般性,我們指定,那麼就解得

直觀理解

我們也可以從另一個角度理解結果。由於是非負的,我們可以將它歸一化,變成具有概率分佈的特性,即,然後我們試圖完成分解,由於現在相當於一個二元聯合概率分佈,那麼就相當於它們的邊緣分佈,即

現在還需要乘上一個,我們可以把它乘到中,不失一般性,我們假設乘到上,那麼就得到

AdaFactor雛形

有了結果後,我們就可以用它來構建更省內存的優化器了,這就是AdaFactor的雛形。簡單來說,當參數是普通一維向量時,優化過程保持不變;但的矩陣時,算出來的梯度也是矩陣,從而也是矩陣,這時候我們對做低秩分解,然後維護兩組緩存變量,分別滑動平均低秩分解後的結果,最後用共同調整學習率:

(把加到上去而不是上去,這是AdaFactor整出來的形式,不是筆者的鍋~).

滑動權重

在Adam以及上述AdaFactor雛形中,滑動權重都是恆爲常數,AdaFactor指出這是不科學的,並提出新的策略。

等價形式

爲了認識到這一點,我們重寫一下Adam的的更新過程:

所以如果設,那麼更新公式就是

問題是這個夠不夠合理呢?答案是可能不大夠。當,這時候就是,也就是用實時梯度來校正學習率,這時候校正力度最大;當時,,這時候是累積梯度平方與當前梯度平方的加權平均,由於,所以意味着當前梯度的權重不爲0,這可能導致訓練不穩定,因爲訓練後期梯度變小,訓練本身趨於穩定,校正學習率的意義就不大了,因此學習率的校正力度應該變小,並且,學習率最好恆定爲常數(這時候相當於退化爲SGD),這就要求時,

新的衰減策略

爲了達到這個目的,AdaFactor採用如下的衰減策略

它滿。但即便如此,也不是任何都適合,必須有好理解,那爲什麼要呢?原論文包含了對它的分析,大家可以去讀讀,但筆者覺得原論文的推導過於晦澀,所以這裏給出自己的理解。

首先,對於來說,一個很容易想到的方案是所有梯度平方的平均,即:

所以這等價於讓。這個方案美中不足的一點是,每一步梯度都是平權的,這不符合直覺,因爲正常來說越久遠的梯度應該越不重要纔對,所以應該適當降低歷史部分權重,而當時,,因此一個簡潔的方案是在式中取,AdaFactor默認的

層自適應

最後,我們還可以進一步根據參數的模長來校正更新量,這個思路來自LAMB優化器,在之前的文章《6個派生優化器的簡單介紹及其實現》中也介紹過。簡單來說,它就是將最後的更新量標準化,然後乘以參數的模長,說白了,就是不管你怎麼折騰,最後的更新量我只要你的方向,而大小由參數本身的模長和預先設置學習率共同決定,使得所有層所有參數的相對變化程度保持一致。

AdaFactor完整版

至此,我們終於可以寫出完整版AdaFactor的更新過程了:

其中是模長的變種,這一步相當於做了個截斷,即時才執行歸一化。原論文中的默認參數爲

如果參數是一維向量而不是矩陣,那麼使用普通的更新公式就行了。此外,論文還提出如果沒有傳入學習率,那麼可以使用爲默認學習率,但筆者看源碼的時候發現這個默認學習率很少使用,基本上還是需要自己傳入學習率的。

開源實現

爲了方便大家使用,筆者開源了自己實現的AdaFactor:

https://github.com/bojone/adafactor

開源包括純keras版和tf.keras版,使用方法跟普通keras優化器一樣,tf.keras版也可以當做一個普通的tensorflow優化器使用。開源實現參考了mesh_tensorflow版的源碼,在此表示感謝。優化器也已經內置在bert4keras中,方便大家調用。

需要提醒的是,用AdaFactor的時候,batch_size最好大一些,因爲本身低秩分解會帶來誤差,而如果batch_size過小,那麼梯度估算本身也帶來較大的誤差,兩者疊加優化過程可能還不收斂。對於預訓練模型來說,batch_size通常還是很大的,所以現在不少預訓練模型開始用AdaFactor優化器了;對於普通的下游任務來說,AdaFactor也可以嘗試,但可能需要多煉煉丹,才能搞出由於無腦Adam的效果。

文章小結

本文介紹了Google提出來的AdaFactor優化器,一個旨在減少顯存佔用的優化器,並且針對性地分析並解決了Adam的一些缺陷。筆者認爲,AdaFactor針對Adam所做的分析相當經典,值得我們認真琢磨體味,對有興趣研究優化問題的讀者來說,更是一個不可多得的分析案例。

當然,沒有什麼絕對能有效的方法,有的只是方法雖好,要想實際有效,依然要用心煉丹。 

夕小瑤的賣萌屋

_

關注&星標小夕,帶你解鎖AI祕籍

訂閱號主頁下方「撩一下」有驚喜哦

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章