【論文筆記】DARTS: Differentiable Architecture Search

論文:https://arxiv.org/pdf/1806.09055.pdf
源碼:https://github.com/quark0/darts

看過NAS的同學都知道,之前神經網絡結構搜索使用的都是強化學習或者進化算法來搜索,當然PNAS是之後的事情,因爲動作空間是離散的,在全局搜索神經網絡架構十分耗費資源,而DARTS這篇文章用了數學方法,巧妙地將搜索空間用概率的方式轉化爲了連續可微形式,然後使用梯度的下降來搜索網絡結構。
--------------------------------------------------------------------------------------看完之後感慨甚多,果然數學纔是最強有力的殺器。

簡介

現有的神經網絡結構搜索的方法雖然有效但是需要大量的計算資源支撐,例如在CIFAR-10和ImageNet數據集上進行搜索,強化學習需要2000 GPU day 進化算法需要3150 GPU day。本文爲了解決這個問題,從一個不同的角度出發,提出了一種有效的結構搜索辦法叫做DARTS(Differentiable ARchitecture Search)。將原有從一個離散的候選結構搜索空間內搜索,替換成一個連續的搜索空間,在這種連續的空間就可以利用可微的性質,此時就可以將搜索任務變化成新的優化目標通過梯度下降的方式來進行優化,就相當於原來是強化學習在離散空間內反覆試錯,現在變成了網絡中的一個優化目標可以直接把網絡結構學出來。
DARTS適用於CNN也適用於RNN,並且相對於離散空間內搜索可微的方法速度快了幾個量級

Contribution

1.本文提出了一種適用於卷積神經網絡和循環神經網絡的可微神經網絡搜索方法
2.通過大量的實驗證明了本文提出的方法在圖片和語言處理模型上都遠遠優於不可微分的搜索方法
3.我們實現了高效的結構搜索,全都歸功於使用了基於梯度的優化方法
4.DARTS具有可遷移的特點

方法

搜索空間

搜索cell作爲最終神經網絡結構的組成單元,然後學習將這些cell堆疊成卷積神經網絡或者循環神經網絡。
一個cell是一個有向無環圖DAG由N個有序節點組成。
每個節點x(i)x^{(i)}都是一個特徵映射,每條有向邊o(i,j)o^{(i,j)}都代表了一種運算操作
如果有一個節點有兩個輸入的話就把兩個輸入做concatenation
每個節點都是通過他之前所有節點的運算得到的(按照tenorflow計算圖理解),用公式表示成如下形式
x(j)=i<jo(i,j)(x(i))x^{(j)}=\sum_{i<j}o^{(i,j)}(x^{(i)})
其中還包含了一個零的操作,用來減少連接
在這裏插入圖片描述

連續鬆弛優化

OO表示候選離散的操作集合(卷積,池化,零)爲了讓空間連續,對所有的結構計算softmax
在這裏插入圖片描述
其中運算操作混合權重的一組節點可以表示爲一個向量α(i,j)\alpha^{(i,j)},搜索完成後,使用最可能的操作替換每一個混合運算操作符就可以得到一個離散的結構。
在這裏插入圖片描述
鬆弛完成後,我們的目標是學習結構α\alpha和權重weightweight w,類似於強化學習或者進化算法,將驗證集上的性能看做最終的獎勵或者擬合程度,DARTS 的目標就是優化驗證集上的loss。
LtrainL_{train}LvalL_{val}分別表示訓練和驗證的loss。這兩個loss不僅決定了結構α\alpha 而且也決定了網絡中的權重w。結構的搜索目標是找到α\alpha^*使Lval(w,α)L_{val}(w^*,\alpha^*)最小,其中ww^*通過最小化訓練loss得到
w=argmaxwLtrain(w,α)w^* = argmax_wL_{train}(w,\alpha^*)公式表示如下
在這裏插入圖片描述
DARTS的僞代碼
在這裏插入圖片描述

近似結構梯度

由於上述兩層嵌套loss內部優化過程比較昂貴,導致無法求解出準確的結構梯度,爲此我們提出了一種簡單地近似方法
在這裏插入圖片描述
其中w表示算法的當前權重,ξ\xi是內部優化步驟的學習率,目的是通過僅使用單步訓練來調整權重w, 來近似ω(α)\omega^*(\alpha)(這個想法真的秀)需要注意的是在公式(6),如果w已經是最優的情況下,此時ξwLtrain(w,α)=0\xi▽_wL_{train}(w,\alpha)=0,這個時候公式(6)就可以化簡成 αLval(w,α)▽_\alpha L_{val}(w,\alpha)
鏈式法則應用到結構梯度上
在這裏插入圖片描述
其中 w=wξwLtrain(w,α)w'=w-\xi▽_w L_{train}(w,\alpha) 表示一次步的前向傳播的權重,上述表達式後面包含了一個計算複雜度很高的矩陣乘法,文中提出有限差分近似的方法解決
在這裏插入圖片描述
此時的時間複雜度就從O(αw)O(|\alpha||w|)降低成了O(α+w)O(|\alpha|+|w|)
(看到這真的驚到我了,真的是很把數學用的出神入化,學好數學是有多麼重要)

推導離散結構

爲了形成離散結構中的每個node,我們從之前的非零候選操作集中取出top-k重要的操作。這些操作的重要性是通過softmax計算得到的。k的取值,CNN中的cell時,k取2,RNN中的cell時 k取1。

在這裏插入圖片描述

實驗

實驗部分,主要在Cifar-10和PTB上進行實驗,分別驗證DARTS在CNN和RNN上的搜索能力
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

讀了這篇paper,給了我很大的衝擊,這種將搜索空間轉化成可微形式進行優化,是我之前想都沒想過的,同時本文的方法優化部分環環相扣,各種數學優化的trick層出不窮,也讓我清晰的認識到,數學纔是AI領域真正的內功。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章