【神經網絡搜索】DARTS: Differentiable Architecture Search

【GiantPandaCV】DARTS將離散的搜索空間鬆弛,從而可以用梯度的方式進行優化,從而求解神經網絡搜索問題。本文首發於GiantPandaCV,未經允許,不得轉載。https://arxiv.org/pdf/1806.09055v2.pdf

1. 簡介

此論文之前的NAS大部分都是使用強化學習或者進化算法等在離散的搜索空間中找到最優的網絡結構。而DARTS的出現,開闢了一個新的分支,將離散的搜索空間進行鬆弛,得到連續的搜索空間,進而可以使用梯度優化的方處理神經網絡搜索問題。DARTS將NAS建模爲一個兩級優化問題(Bi-Level Optimization),通過使用Gradient Decent的方法進行交替優化,從而可以求解出最優的網絡架構。DARTS也屬於One-Shot NAS的方法,也就是先構建一個超網,然後從超網中得到最優子網絡的方法。

2. 貢獻

DARTS文章一共有三個貢獻:

  • 基於二級最優化方法提出了一個全新的可微分的神經網絡搜索方法。
  • 在CIFAR-10和PTB(NLP數據集)上都達到了非常好的結果。
  • 和之前的不可微分方式的網絡搜索相比,效率大幅度提升,可以在單個GPU上訓練出一個滿意的模型。

筆者這裏補一張對比圖,來自之前筆者翻譯的一篇綜述:<NAS的挑戰和解決方案-一份全面的綜述>

ImageNet上各種方法對比,DARTS屬於Gradient Optimization方法

簡單一對比,DARTS開創的Gradient Optimization方法使用的GPU Days就可以看出結果非常驚人,與基於強化學習、進化算法等相比,DARTS不愧是年輕人的第一個NAS模型。

3. 方法

DARTS採用的是Cell-Based網絡架構搜索方法,也分爲Normal Cell和Reduction Cell兩種,分別搜索完成以後會通過拼接的方式形成完整網絡。在DARTS中假設每個Cell都有兩個輸入,一個輸出。對於Convolution Cell來說,輸入的節點是前兩層的輸出;對於Recurrent Cell來說,輸入爲當前步和上一步的隱藏狀態。

DARTS核心方法可以用下面這四個圖來講解。

DARTS Overview

(a) 圖是一個有向無環圖,並且每個後邊的節點都會與前邊的節點相連,比如節點3一定會和節點0,1,2都相連。這裏的節點可以理解爲特徵圖;邊代表採用的操作,比如卷積、池化等。

引入數學標記:

節點(特徵圖)爲: \(x^{(i)}\) 代表第i個節點對應的潛在特徵表示(特徵圖)。

邊(操作)爲: \(o^{(i,j)}\) 代表從第i個節點到第j個節點採用的操作。

每個節點的輸入輸出如下面公式表示,每個節點都會和之前的節點相連接,然後將結果通過求和的方式得到第j個節點的特徵圖。

\[x^{(j)}=\sum_{i\lt j} o^{(i, j)}(x^{(i)}) \]

所有的候選操作爲 \(\mathcal{O}\), 在DARTS中包括了3x3深度可分離卷積、5x5深度可分離卷積、3x3空洞卷積、5x5空洞卷積、3x3最大化池化、3x3平均池化,恆等,直連,共8個操作。

(b) 圖是一個超網,將每個邊都擴展了8個操作,通過這種方式可以將離散的搜索空間鬆弛化。具體的操作根據如下公式:

\[\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x) \]

這個可以分爲兩個部分理解,一個是\(o(x)\)代表操作,一個代表選擇概率 \(\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}\),這是一個softmax構成的概率,其中\(\alpha_o^{(i,j)}\)表示 第i個節點到第j個節點之間操作的權重,這也是之後需要搜索的網絡結構參數,會影響該操作的概率。即以下公式:

\[softmax(\alpha)\times operation_{w}(x) \]

左側代表當前操作的概率,右側代表當前操作的參數。

(c)和(d)圖 是保留的邊,訓練完成以後,從所有的邊中找到概率最大的邊,即以下公式:

\[o^{(i, j)}=\operatorname{argmax}_{o \in \mathcal{O}} \alpha_{o}^{(i, j)} \]

4. 數學推導

DARTS將NAS問題看作二級最優化問題,具體定義如下:

\[\begin{aligned} \min _{\alpha} & \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \\ \text { s.t. } & w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{\text {train }}(w, \alpha) \end{aligned} \]

\(w*(\alpha)\) 代表當前網絡結構參數\(\alpha\)的情況下,訓練獲得的最優的網絡結構參數。

第一行代表:在驗證數據集中,在特定網絡操作參數w下,通過訓練獲得最優的網絡結構參數\(\alpha\)

第二行表示:在訓練數據集中,在特定網絡結構參數\(\alpha\)下,通過訓練獲得最優的網絡操作參數\(w\)

條件:在結構確定的情況下,獲得最優的網絡操作權重

​ ----- 結構確定,訓練好卷積核

目標:在網絡操作權重確定的情況下,獲得最優的結構

​ ----- 卷積核不動,選擇更好的結構

最簡單的方法是通過交替優化參數\(w\)和參數\(\alpha\), 來得到最優的結果,僞代碼如下:

DARTS僞代碼

交替優化的複雜度非常高,是\(O(|\alpha||w|)\), 這種複雜度不可能投入使用,所以要將複雜度進行優化,用複雜度低的公式近似目標函數。

\[\nabla_{\alpha} \mathcal{L}_{\text {val }}\left(w^{*}(\alpha), \alpha\right) \approx \nabla_{\alpha} \mathcal{L}_{v a l}\left(w-\xi \nabla_{w} \mathcal{L}_{t r a i n}(w, \alpha), \alpha\right) \]

這種近似方法在Meta Learning中經常用到,詳見《Model-agnostic meta-learning for fast adaptation of deep networks》,也就是通過使用單個step的訓練調整w,讓這個結果來近似\(w*(\alpha)\)

然後對右側公式進行推導,得到梯度優化以後的表達式:

師兄提供


這裏求梯度使用的是鏈式法則,回顧一下:

\[z=f(g1(x),g2(x)) \]

則梯度計算爲:

\[\frac{\partial z}{\partial x}=\frac{\partial g1}{\partial x} \times \frac{\partial z}{\partial g1} + \frac{\partial g2}{\partial x}\times\frac{\partial z}{\partial g2} \]

或者

師兄提供

上述公式中Di代表對\(f(g1(\alpha),g2(\alpha))\)的第i項的偏導。


手敲公式太痛苦了

整理以後結果就是:

計算結果

減號後邊的是二次梯度,權重的梯度求解很麻煩,這裏使用泰勒公式將二階轉爲一階(h是一個很小的值)。

泰勒公式複習

利用最右下角的公式:

\(A=\nabla_{\omega^{\prime}} \mathcal{L}_{v a l}\left(\omega^{\prime}, \alpha\right)\),\(h=\epsilon\), \(x_0=w\), \(f=\nabla_{\alpha} \mathcal{L}_{\text {train }}(\cdot, \cdot)\), 代入可得(其中經驗上設置\(\epsilon=\frac{0.01}{||\nabla_{w'}\mathcal{L}_{val}(w',\alpha)||_2}\))

\[\nabla_{\alpha, \omega}^{2} \mathcal{L}_{\text {train }}(\omega, \alpha) \cdot \nabla_{\omega^{\prime}} \mathcal{L}_{\text {val }}\left(\omega^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{\text {train }}\left(\omega^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{\text {train }}\left(\omega^{-}, \alpha\right)}{2 \epsilon} \]

其中

\[\omega^{\pm}=\omega \pm \epsilon \nabla_{\omega^{\prime}} \mathcal{L}_{v a l}\left(\omega^{\prime}, \alpha\right) \]

這樣就可以將二次梯度轉化爲多個一次梯度。到這裏複雜度從\(O(|\alpha||w|)\)優化到\(O(|\alpha|+|w|)\)

一階近似:\(\xi=0\), 下面式子的二階倒數部分就消失了,這樣模型的梯度計算可能不夠準確,效果雖然不如二階,但是計算速度快。只需要假設當前的\(w\)就是\(w*(\alpha)\), 然後啓發式優化驗證集上的loss值即可。

計算結果

代碼實現上也有一定的區別,代碼將在下一篇講解。

5. 實驗設置

這裏我們暫且先關注CIFAR10上的實驗效果。DARTS構成網絡的方式之前已經提到了,首先爲每個單元內布使用DARTS進行搜索,通過在驗證集上的表現決定最好的單元然後使用這些單元構建更大的網絡架構,然後從頭開始訓練,報告在測試集上的表現。

CIFAR10上搜索操作有:

  • 3x3 & 5x5 可分離卷積
  • 3x3 & 5x5 空洞可分離卷積
  • 3x3 max & avg pooling
  • identiy
  • zero

實驗詳細設置:

  • 所有操作的stride=1, 爲了保證他們空間分辨率,使用了padding。

  • 卷積操作使用的是ReLU-Conv-BN的順序,並且每個可分離卷積會被使用兩次。

  • 卷積單元包括了7個節點,輸出節點爲所有中間節點concate以後的結果。

  • 網絡整體深度的1/3和2/3處強制設置了reduction cell來降低空間分辨率。

  • 網絡結構參數\(\alpha_{\text{normal}}\)是被所有normal cell共享的,同理\(\alpha_{\text{reduce}}\)是被所有reduction cell共享的。

  • 並沒有使用全局batch normalization, 使用的是batch-specific statistic batch normalization

  • CIFAR10一半的訓練集作爲驗證集。

  • 8個單元的消亡了使用DARTS訓練50個epoch, batch size設置爲64, 初始通道個數爲16。

  • 使用momentum SGD來優化權重,初始學習率設置爲0.025,momentum 0.9 weight decay爲0.0004.

  • 網絡架構參數\(\alpha\) 使用0作爲初始化,使用Adam優化器來優化\(\alpha\)參數,初始學習率設置爲0.0004,momentum爲(0.5,0.999)weight decay=0.001。

CIFAR10上搜索結果和其他算法對比

可以看到,搜索結果最終是優於AmoebaNet-A和NASNet-A。具體搜索得到的Normal Cell和Reduction Cell可視化如下:

Normal Cell & Reduction Cell for CIFAR10

網絡評價

網絡優化對初始化值是非常敏感的,爲了確定最終的網絡結構,DARTS將使用隨機種子運行四次,每次得到的Cell都會在訓練集上從頭開始訓練很短一段時間大概100 epochs , 然後根據驗證集上得到的最優結果決定最終的架構。

爲了驗證被選擇的架構:

  • 隨機初始化權重
  • 從頭開始訓練
  • 報告測試集上的模型表現

CIFAR10搜索的模型遷移到ImageNet更多細節:

  • 20個單元的大型網絡使用了96的batch size, 訓練了600個epoch
  • 初始通道個數由16修改爲36,爲了讓模型的參數和其他模型參數量相當。
  • 其他參數設置和搜索過程中參數一樣
  • 使用了cutout的數據增強方法,以0.2的概率進行path dropout
  • 使用了auxiliary tower(輔助頭,在這裏施加loss, 提前進行反向傳播,InceptionV3中提出)
  • 使用PyTorch在單個GPU上花費1.5天時間訓練完ImageNet,獨立訓練10次作爲最終的結果。

CIFAR10上搜索結果

使用二階優化方法+cutout的數據增強方法,DARTS能達到約2.76的準確率,筆者使用nni進行了實驗,最終結果是2.6%的Test Error。

nni上darts的實驗結果

6. 致謝&參考

感謝師兄提供的資料,以及知乎上兩位大佬,他們文章鏈接如下:

薰風讀論文|DARTS—年輕人的第一個NAS模型 https://zhuanlan.zhihu.com/p/156832334

【論文筆記】DARTS公式推導 https://zhuanlan.zhihu.com/p/73037439

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章