【神經網絡搜索】Efficient Neural Architecture Search

【GiantPandaCV導語】本文介紹的是Efficient Neural Architecture Search方法,主要是爲了解決之前NAS中無法完成權重重用的問題,首次提出了參數共享Parameter Sharing的方法來訓練網絡,要比原先標準的NAS方法降低了1000倍的計算代價。從一個大的計算圖中挑選出最優的子圖就是ENAS的核心思想,而子圖之間都是共享權重的。

https://arxiv.org/pdf/1802.03268v2.pdf

1. 摘要

ENAS是一個快速、代價低的自動網絡設計方法。在ENAS中,控制器controller通過在大的計算圖中搜索挑選一個最優的子圖來得到網絡結構。

  • controller使用Policy Gradient算法進行訓練,通過最大化驗證集上的期望準確率作爲獎勵reward。
  • 被挑選的子圖將使用經典的CrossEntropy Loss進行訓練。

子網絡之間的權重共享可以讓ENAS性能更強大的性能,同時要比經典的NAS方法降低了約1000倍的計算代價。

2. 簡介

NAS-RL使用了450個GPU訓練了3-4天,花費了32,400-43,200個GPU hours纔可以訓練出一個合適的網絡,需要大量的計算資源。NAS的計算瓶頸就在於需要讓每個子模型從頭開始收斂,訓練完成後就廢棄掉其訓練好的權重。

本文主要貢獻是通過讓所有子模型共享權重、避免從頭開始訓練,從而有效提升了NAS的訓練效率。隨後的子模型可以通過遷移學習的方法加速收斂速度、從而加速訓練。

ENAS可以做到使用單個NVIDIA GTX 1080Ti顯卡,只需要花費16個小時。同時在CIFAR10上可以達到2.89%的test error。

3. 方法

3.1 一個例子

ENAS可以看作是從一個超網中得到一個自網絡,如下圖所示。6個節點相互連接得到的就是超網(是一個有向無環圖),通過controller得到紅色的路徑就是其中的一個子網絡。

節點代表局部計算、邊代表信息的流動

舉一個具體的例子,假設當前有4個節點:

Controller示意圖

上圖是controller,具體實現是一個LSTM,需要做出以下決策:

  • 激活哪個邊
  • 對應Node選擇什麼操作

第一個Node,controller首先採樣一個激活函數,這裏採用的是tanh,然後這個激活會接收x和h作爲輸入。

第二個Node,先採樣上一個index=1,說明Node2應該和Node1相連接;然後再採樣一個激活函數relu。

第三個Node,先採樣上一個index=2,說明Node3應該和Node2相連接;然後採樣一個激活函數Relu。

第四個Node,先採樣上一個index=1,說明Node4應該和Node1相連接,然後採樣一個激活函數tanh。

結束後發現有兩個節點是loose end, ENAS的做法是將兩者結果做一個平均,得到最終輸出。

超圖和搜索得到的子網絡結果

在上述例子中,假設節點數量爲N,一共使用了4個激活函數可選。搜索空間大小爲:\(4^N\times N!\)

其中\(4^N\)代表N個節點可選的4個激活函數組成的空間,\(N!\) 代表節點的連接情況,之所以是階乘也很容易理解,因爲隨後的Node只能連接之前出現過的Node。

3.2 ENAS訓練流程

在ENAS中,有兩組可學習參數,Controller LSTM中的參數\(\theta\) 和 子模型共享的權重參數\(w\)。具體流程是:

  • LSTM sample出一個子模型,然後訓練模型\(w\), 通過標準的反向傳播算法進行訓練,訓練完成以後在驗證集上進行測試。
  • 通過驗證集上結果反饋給LSTM,計算\(\theta\)的梯度,更新LSTM的參數。
  • 如此反覆,可以訓練出一個LSTM能夠讓模型在驗證集上的性能最佳。

第一步:訓練共享參數w

首先固定住controller的參數,然後使用蒙特卡洛估計來計算梯度,更新w權重:

m是從\(\pi(m;\theta)\) 中採樣得到的模型,對於所有的模型計算模型損失函數的期望。右側公式是梯度的無偏估計。

第二步:訓練controller 參數\(\theta\)

這一步固定住w,更新controller參數,希望可以得到的Reward值(也就是驗證集準確率)儘可能大。

這裏使用的是REINFORCE算法來進行計算的,具體內容可以查看NAS-RL那篇文章中的講解。

3.3 marco search space

有了上邊的例子做鋪墊,卷積的這部分就很好理解了,區別有幾點:

  • 節點操作不同,這裏可以是3x3卷積、5x5卷積、平均池化、3x3最大池化、3x3深度可分離卷積,5x5深度可分離卷積 一共六個操作。
  • 上圖Node3輸出了兩個值,代表先將node1和node2的輸出tensor合併,然後在經過maxpool操作。

計算卷積網絡設計的空間複雜度,對於第k個節點,頂多可以選取k-1個層,所以在第k層就有\(2^{k-1}\)種選擇,而這裏假設一共有L個層需要做從6個候選操作中做選擇。那麼在不考慮連線的情況下就有\(6^L\)可能被挑選的操作,由於所有連線都是獨立事件,那複雜度計算就是:\(6^L\times 2^{L(L-1)/2}\)(除以2是因爲連線具有對稱性,採樣1,2和2,1結果是一致的)。

3.4 micro search space

ENAS中首次提出了搜索一個一個單元,然後將單元組合拼接成整個網絡。其中單元分爲兩種類型,一種是Conv Cell 該單元不改變特徵圖的空間分辨率;另外一種是Reduction Cell 該單元會將空間分辨率降低爲原來的一半。

Cell-Based

假定每個cell裏邊有B個節點,由於網絡設定是node1和node2是單元的輸入,所以剛開始這部分需要特殊處理,固定兩個單元,搜索隨後的單元,即還剩下B-2個節點需要搜索。

Controller for cells

如上圖所示,從node3開始生成,首先生成兩個需要連接的兩個對象,indexA和indexB; 然後生成兩個op, 分別是sep 5x5和直連id。將操作sep 5x5施加到indexA對應節點上;將操作直連施加到indexB對應節點上,然後通過add的方式融合特徵。

生成的結果,注意前兩個node是固定的

搜索空間複雜度計算:首先分爲Conv Cell和Reduction Cell,由於他們並沒有本質不同,只是所有的操作的stride設置爲2,複雜度也是一樣的。

假定當前是第i個節點,可以選擇來自先前i-1個節點中的兩個節點,並且可選操作有5個。假設只選擇一個節點,那麼複雜度是\(5\times (B-2)!\), 由於要選擇兩個節點,兩個節點的選擇是互相獨立的,所以複雜度計算變爲:\((5\times (B-2)!)^2\) 。而又有Reduction Cell和Conv Cell也是互相獨立的,所以複雜度變爲\((5\times (B-2)!)^4\) ,計算完畢。

4. 實驗結果

主要是在NLP中常用的語料庫Penn Treebank和CV中經典的數據集CIFAR-10上進行了實驗。

4.1 語言模型

在單個GTX 1080Ti上訓練了10個小時,達到了55.8的test perplexity, 下圖是通過ENAS找到的RNN單元。

通過搜索發現的RNN單元

結果如下:

ENAS和其他結果對比

4.2 圖像分類

數據集:CIFAR10有5w張訓練圖片和1w張測試圖片,使用標準的數據預處理和數據增強方法:如將訓練圖片padding到40x40大小,然後隨機裁剪到32x32,水平隨機反轉。

訓練細節: 共享權重w使用Nesterov momentum來訓練,使用cosine schedule調整lr,lr最大設置爲0.05,最小設置爲0.001,T0=10, Tmul=2。每個子網絡設置運行310個epoch。權重初始化使用He initialization。weight decay設置爲\(10^{-4}\)

controller的設置細節,policy gradient的權重\(\theta\)使用均勻的從[-0.1,0.1]初始化,使用0.00035的學習率,使用Adam優化器,設置tanh常數爲2.5 temerature 設置爲5.0; 給controller 得到的熵添加0.1的權重。

在macro搜索空間中,通過在skip connection兩層之間添加KL 散度來增加稀疏性, KL散度項對應的權重設置爲0.8.

使用Macro空間得到的搜索結果

Micro空間搜索得到的結果

實驗結果對比如下:

實驗結果對比

5. 代碼實現

代碼這裏參考NNI中的實現,以macro爲例,ENASLayer實現如下:

class ENASLayer(mutables.MutableScope):

    def __init__(self, key, prev_labels, in_filters, out_filters):
        super().__init__(key)
        self.in_filters = in_filters
        self.out_filters = out_filters

        self.mutable = mutables.LayerChoice([
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1),
            SEConvBranch(in_filters, out_filters, 3, 1, 1, reduction=4)
        ])
        if len(prev_labels) > 0:
            self.skipconnect = mutables.InputChoice(
                choose_from=prev_labels, n_chosen=None)
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

    def forward(self, prev_layers):
        out = self.mutable(prev_layers[-1])
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[:-1])
            if connection is not None:
                out += connection
        return self.batch_norm(out)

其中的mutables是NNI中的一個核心類,可以從LayerChoice所提供的選擇中挑選一個操作,其中最後一個SEConvBranch是筆者自己補充上去的。

  • mutable LayerChoice就是從備選選項中選擇其中一個操作
  • mutable InputChoice是選擇前幾層節點進行連接。

主幹網絡如下:

class GeneralNetwork(nn.Module):
    def __init__(self, num_layers=6, out_filters=12, in_channels=3, num_classes=10,
                 dropout_rate=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.out_filters = out_filters
        self.dropout_rate = dropout_rate

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_filters)
        )

        pool_distance = self.num_layers // 3
        # 進行pool操作是num_layers // 3
        self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
        self.dropout = nn.Dropout(self.dropout_rate)

        self.layers = nn.ModuleList()  # convolutional
        self.pool_layers = nn.ModuleList()  # reduction

        labels = []
        for layer_id in range(self.num_layers):  # 設置12個layer
            labels.append("layer_{}".format(layer_id))

            if layer_id in self.pool_layers_idx:  # 如果使用pool
                self.pool_layers.append(FactorizedReduce(
                    self.out_filters, self.out_filters))

            self.layers.append(  # 相當於Node節點
                ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)

    def forward(self, x):
        bs = x.size(0)
        cur = self.stem(x)  

        layers = [cur]

        for layer_id in range(self.num_layers):
            cur = self.layers[layer_id](layers)
            layers.append(cur)
            if layer_id in self.pool_layers_idx:
                # 如果輪到了池化層
                for i, layer in enumerate(layers):
                    layers[i] = self.pool_layers[self.pool_layers_idx.index(
                        layer_id)](layer)
                cur = layers[-1]

        cur = self.gap(cur).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits

需要注意有幾個點:

  • self.stem是第一個node,手動設置的。
  • 池化是強制設置的,在某些層規定進行下采樣。

搜索過程調用了NNI提供的API:

model = GeneralNetwork()
trainer = enas.EnasTrainer(model,
                           loss=criterion,
                           metrics=accuracy,
                           reward_function=reward_accuracy,
                           optimizer=optimizer,
                           callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                           batch_size=args.batch_size,
                           num_epochs=num_epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           log_frequency=args.log_frequency,
                           mutator=mutator)

mutator是NNI提供的一個類,就是上述提到的controller,這裏具體調用的是EnasMutator。

def _sample_layer_choice(self, mutable):
    # 選擇 某個層 只需要選一個就可以了
    self._lstm_next_step() # 讓_inputs在lstm中進行一次前向傳播

    logit = self.soft(self._h[-1]) # linear 從隱藏層embedd得到可選的層的邏輯評分

    if self.temperature is not None:
        logit /= self.temperature # 一個常量 貌似是RL中的trick

    if self.tanh_constant is not None:
        # tanh_constant * tanh(logits) 用tanh再激活一次(可選)
        logit = self.tanh_constant * torch.tanh(logit)

    if mutable.key in self.bias_dict:
        logit += self.bias_dict[mutable.key]
        # 對卷積層進行了偏好處理,如果是卷積層,那就在對應的值加上一個0.25,增大被選中的概率
    
    # softmax, view(-1), 
    branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) 
    # 依據概率來選下角標,如果數量不爲1,選擇的多箇中沒有重複的 
    # eg: [100,1,1] 最有可能選擇100對應的下標0
        
    log_prob = self.cross_entropy_loss(logit, branch_id) # 交叉熵損失函數 - 判斷logit和branchid分佈是否相似程度

    self.sample_log_prob += self.entropy_reduction(log_prob) # 求和或者求平均
    
    entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type ??
    
    self.sample_entropy += self.entropy_reduction(entropy) # 樣本熵?

    self._inputs = self.embedding(branch_id) # 得到對應id的embedding, 從選擇空間 - 映射到 - 隱空間

    return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) # 將選擇變成one_hot向量

這部分是EnasMutator中一個核心函數,實現的是REINFORCE算法。

if self.entropy_weight: # 交叉熵權重 
	reward += self.entropy_weight * self.mutator.sample_entropy.item() # 得到樣本熵

6. 總結

ENAS核心就是提出了一個超網,每次從超網中採樣一個小的網絡進行訓練。所有的子網絡都是共享超網中的一套參數,這樣每次訓練就不是從頭開始訓練,而是進行了遷移學習,加快了訓練速度。

有註釋代碼鏈接如下:https://github.com/pprp/SimpleCVReproduction/tree/master/nni

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章