神經網絡架構搜索——可微分搜索(DARTS)

背景

神經網絡架構搜索之前主流的方法主要包括:強化學習,進化學習。他們的搜索空間都是不可微的,Differentiable Architecture Search 這篇文章提出了一種可微的方法,可以用梯度下降來解決架構搜索的問題,所以在搜索效率上比之前不可微的方法快幾個數量級。可以這樣通俗的理解:之前不可微的方法,相當於是你定義了一個搜索空間(比如3x3和5x5的卷積核),然後神經網絡的每一層你可以從搜索空間中選一種構成一個神經網絡,跑一下這個神經網絡的訓練結果,然後不斷測試其他的神經網絡組合。這種方法,本質上是從很多的組合當中儘快的搜索到效果很好的一種,但是這個過程是黑盒,需要有大量的驗證過程,所以會很耗時。而這篇文章把架構搜索融合到模型當中一起訓練

算法核心思想

DARTS算法示意圖

由上圖可分析:

  • (a) 定義了一個cell單元,可看成有向無環圖,裏面4個node,node之間的edge代表可能的操作(如:3x3 sep 卷積),初始化時unknown。

  • (b) 把搜索空間連續鬆弛化,每個edge看成是所有子操作的混合(softmax權值疊加)。

  • © 聯合優化,更新子操作混合概率上的edge超參(即架構搜索任務)和 架構無關的網絡參數

  • (d) 優化完畢後,inference 直接取概率最大的子操作即可。

搜索空間

DARTS要做的事情,是訓練出來兩個Cell(Norm-Cell和Reduce-Cell),然後把Cell相連構成一個大網絡,而超參數layers可以控制有多少個cell相連,例如layers = 20表示有20個cell前後相連。

  • Norm-Cell: [輸入與輸出的FeatureMap尺寸保持一致]
  • Reduce-Cell: [輸出的FeatureMap尺寸減小一半]
Cell的組成

Cell由輸入節點,中間節點,輸出節點,邊四部分構成,我們規定每一個cell有兩個輸入節點一個輸出節點,Norm-Cell和Reduce-Cell的結構相同,不過操作不同。

  • 輸入節點:對於卷積網絡來說,兩個輸入節點分別是前兩層(layers)cell的輸出,對於循環網絡(Recurrent)來說,輸入時當前層的輸入和前一層的狀態。

  • 中間節點:每一箇中間節點都由它的前繼通過邊再求和得來。

  • 輸出節點:由每一箇中間節點concat起來。

  • 邊:邊代表的是operation(比如3*3的卷積),在收斂得到結構的過程中,兩兩節點中間所有的邊(DARTS預定義了8中不同的操作)都會存在並參與訓練,最後加權平均,這個權就是我們要訓練的東西,我們希望得到的結果是效果最好的邊它的權重最大。

DARTS實際預定義的Cell結構與論文中示意圖的表示略有不同,完整的Cell結構包含兩個輸入節點,四個中間節點和一個輸出節點,如下圖所示:

Search-Cell結構

全連接的情況下,N0中間節點有兩個前繼節點;N1,N2,N3分別有3,4,5個前繼節點。每個節點之間有對應8個不同的預定義操作,共同構成一組邊。

首先我們定義如下公式用softmax歸一化alpha處理一組邊:

oˉ(i,j)(x)=oOexp(αo(i,j))oOexp(αo(i,j))o(x) \bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x)
通過公式可知每個操作對應一個權值(即alpha),這就是我們要訓練的參數,我們把這些alpha稱作一個權值矩陣,alpha值越大代表的操作在這組邊中越重要。

然後每組中間節點公式表示如下,即所有前繼節點累加作爲當前節點的輸出:
x(i)=j<io(i,j)(x(j)) x^{(i)}=\sum_{j<i} o^{(i, j)}\left(x^{(j)}\right)

我們收斂到最後希望得到一個權值矩陣,這個矩陣當中權值越大的邊,留下來之後效果越好。

優化策略

通過前面定義的搜索空間,我們的目的是通過梯度下降優化alpha矩陣。我們把神經網絡原有的權重稱爲W矩陣。爲了實現端到端的優化,我們希望同時優化兩個矩陣使得結果變好。上述兩層優化是有嚴格層次的,爲了使兩者都能同時達到優化的策略,一個樸素的想法是:在訓練集上固定alpha矩陣的值,然後梯度下降W矩陣的值,在驗證集上固定W矩陣的值,然後梯度下降alpha的值,循環往復直到這兩個值都比較理想。這個過程有點像k-means的過程,先定了中心,再求均值,再換中心,再求均值。需要注意的是驗證集和訓練集的劃分比例是1:1的,因爲對於alpha矩陣來說,驗證集就是它的訓練集。
minαLval(w(α),α) s.t. w(α)=argminwLtrain(w,α) \begin{array}{cl} \min _{\alpha} & \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \\ \text { s.t. } & w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{t r a i n}(w, \alpha) \end{array}
但是這個方法雖然可以工作,但是效果不是很好,由於這種雙優化的問題很難求得精確解(因爲需要反覆迭代求解兩個參數),所以採用一種近似的迭代優化步驟來交替更新兩個參數,算法如下:

20190714153014.png

具體的公式推導流程可參考(DARTS公式推導 https://zhuanlan.zhihu.com/p/73037439)
αLval(ω(α),α)αLval(ωξωLtrain(ω,α),α) \begin{aligned} \nabla_{\alpha} \mathcal{L}_{v a l}\left(\omega^{*}(\alpha), \alpha\right) \\ \approx \nabla_{\alpha} \mathcal{L}_{v a l}(&\left.\omega-\xi \nabla_{\omega} \mathcal{L}_{t r a i n}(\omega, \alpha), \alpha\right) \end{aligned}

αLval(ωξωLtrain(ω,α),α)=αLval(ω,α)ξα,ω2Ltrain(ω,α)ωLval(ω,α) \begin{aligned} & \nabla_{\alpha} \mathcal{L}_{v a l}\left(\omega-\xi \nabla_{\omega} \mathcal{L}_{t r a i n}(\omega, \alpha), \alpha\right) \\ =& \nabla_{\alpha} \mathcal{L}_{v a l}\left(\omega^{\prime}, \alpha\right)-\xi \nabla_{\alpha, \omega}^{2} \mathcal{L}_{t r a i n}(\omega, \alpha) \cdot \nabla_{\omega^{\prime}} \mathcal{L}_{v a l}\left(\omega^{\prime}, \alpha\right) \end{aligned}

α,w2Ltrain(w,α)wLval(w,α)αLtrain(w+,α)αLtrain(w,α)2ϵ \nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{t r a i n}\left(w^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{t r a i n}\left(w^{-}, \alpha\right)}{2 \epsilon}

生成最終Cell結構

根據前面所述,我們要訓練出來一個alpha矩陣,使得權重大的邊保留下來,所以在這個結構收斂了之後還需要做一個生成最終Cell的過程。那這個時候你可能會問,爲什麼不把之前的結構直接用上呢?因爲邊太多,結構太複雜,參數太多不好訓練,所以作者希望能生成一個更簡單的網絡結構,接下來我們說生成的方法。

對於每一箇中間節點來說,我們最多保留兩個最強壯的前繼;對於兩兩節點之間的邊,我們只保留權重最大的一條邊,我們定義一下什麼是最強壯的前繼。假設一個節點有三個前繼,那我們選哪兩個呢?把前繼和當前節點之間權重最高的那條邊代表那個前繼的強壯程度,我們選最強壯的兩個前繼。節點之間只保留權重最大的那條邊。

image-20200516172750548

normal cell search

reduce cell search

網絡結構堆疊

下圖,展示了Normal-Cell與Reduce-Cell的連接方式,代碼描述是在1/3處和2/3處添加兩個Reduce-Cell。比如,在CIFAR-10數據集上的網絡結構需要20個Cell串聯。NetWork=6*Normal-Cell+Reduce-Cell+6*Normal-Cell+Reduce-Cell+6*Normal-Cell

Norm-Cell與Reduce-Cell串聯

由於,Cell結構是兩個輸入的,因此詳細的Cell連接方式如下所示:

具體連接方式

結果

CIFAR-10

CIFAR-10結果

ImageNet

ImageNet結果

參考

Liu, H., Simonyan, K., & Yang, Y. (2019). DARTS: Differentiable Architecture Search. ArXiv, abs/1806.09055.

DARTS 可微 架構搜索
AutoDL

DARTS公式推導


更多內容關注微信公衆號【AI異構】

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章