【GiantPandaCV導語】本文介紹的是Efficient Neural Architecture Search方法,主要是爲了解決之前NAS中無法完成權重重用的問題,首次提出了參數共享Parameter Sharing的方法來訓練網絡,要比原先標準的NAS方法降低了1000倍的計算代價。從一個大的計算圖中挑選出最優的子圖就是ENAS的核心思想,而子圖之間都是共享權重的。
1. 摘要
ENAS是一個快速、代價低的自動網絡設計方法。在ENAS中,控制器controller通過在大的計算圖中搜索挑選一個最優的子圖來得到網絡結構。
- controller使用Policy Gradient算法進行訓練,通過最大化驗證集上的期望準確率作爲獎勵reward。
- 被挑選的子圖將使用經典的CrossEntropy Loss進行訓練。
子網絡之間的權重共享可以讓ENAS性能更強大的性能,同時要比經典的NAS方法降低了約1000倍的計算代價。
2. 簡介
NAS-RL使用了450個GPU訓練了3-4天,花費了32,400-43,200個GPU hours纔可以訓練出一個合適的網絡,需要大量的計算資源。NAS的計算瓶頸就在於需要讓每個子模型從頭開始收斂,訓練完成後就廢棄掉其訓練好的權重。
本文主要貢獻是通過讓所有子模型共享權重、避免從頭開始訓練,從而有效提升了NAS的訓練效率。隨後的子模型可以通過遷移學習的方法加速收斂速度、從而加速訓練。
ENAS可以做到使用單個NVIDIA GTX 1080Ti顯卡,只需要花費16個小時。同時在CIFAR10上可以達到2.89%的test error。
3. 方法
3.1 一個例子
ENAS可以看作是從一個超網中得到一個自網絡,如下圖所示。6個節點相互連接得到的就是超網(是一個有向無環圖),通過controller得到紅色的路徑就是其中的一個子網絡。
舉一個具體的例子,假設當前有4個節點:
上圖是controller,具體實現是一個LSTM,需要做出以下決策:
- 激活哪個邊
- 對應Node選擇什麼操作
第一個Node,controller首先採樣一個激活函數,這裏採用的是tanh,然後這個激活會接收x和h作爲輸入。
第二個Node,先採樣上一個index=1,說明Node2應該和Node1相連接;然後再採樣一個激活函數relu。
第三個Node,先採樣上一個index=2,說明Node3應該和Node2相連接;然後採樣一個激活函數Relu。
第四個Node,先採樣上一個index=1,說明Node4應該和Node1相連接,然後採樣一個激活函數tanh。
結束後發現有兩個節點是loose end, ENAS的做法是將兩者結果做一個平均,得到最終輸出。
在上述例子中,假設節點數量爲N,一共使用了4個激活函數可選。搜索空間大小爲:\(4^N\times N!\)
其中\(4^N\)代表N個節點可選的4個激活函數組成的空間,\(N!\) 代表節點的連接情況,之所以是階乘也很容易理解,因爲隨後的Node只能連接之前出現過的Node。
3.2 ENAS訓練流程
在ENAS中,有兩組可學習參數,Controller LSTM中的參數\(\theta\) 和 子模型共享的權重參數\(w\)。具體流程是:
- LSTM sample出一個子模型,然後訓練模型\(w\), 通過標準的反向傳播算法進行訓練,訓練完成以後在驗證集上進行測試。
- 通過驗證集上結果反饋給LSTM,計算\(\theta\)的梯度,更新LSTM的參數。
- 如此反覆,可以訓練出一個LSTM能夠讓模型在驗證集上的性能最佳。
第一步:訓練共享參數w
首先固定住controller的參數,然後使用蒙特卡洛估計來計算梯度,更新w權重:
m是從\(\pi(m;\theta)\) 中採樣得到的模型,對於所有的模型計算模型損失函數的期望。右側公式是梯度的無偏估計。
第二步:訓練controller 參數\(\theta\)
這一步固定住w,更新controller參數,希望可以得到的Reward值(也就是驗證集準確率)儘可能大。
這裏使用的是REINFORCE算法來進行計算的,具體內容可以查看NAS-RL那篇文章中的講解。
3.3 marco search space
有了上邊的例子做鋪墊,卷積的這部分就很好理解了,區別有幾點:
- 節點操作不同,這裏可以是3x3卷積、5x5卷積、平均池化、3x3最大池化、3x3深度可分離卷積,5x5深度可分離卷積 一共六個操作。
- 上圖Node3輸出了兩個值,代表先將node1和node2的輸出tensor合併,然後在經過maxpool操作。
計算卷積網絡設計的空間複雜度,對於第k個節點,頂多可以選取k-1個層,所以在第k層就有\(2^{k-1}\)種選擇,而這裏假設一共有L個層需要做從6個候選操作中做選擇。那麼在不考慮連線的情況下就有\(6^L\)可能被挑選的操作,由於所有連線都是獨立事件,那複雜度計算就是:\(6^L\times 2^{L(L-1)/2}\)(除以2是因爲連線具有對稱性,採樣1,2和2,1結果是一致的)。
3.4 micro search space
ENAS中首次提出了搜索一個一個單元,然後將單元組合拼接成整個網絡。其中單元分爲兩種類型,一種是Conv Cell 該單元不改變特徵圖的空間分辨率;另外一種是Reduction Cell 該單元會將空間分辨率降低爲原來的一半。
假定每個cell裏邊有B個節點,由於網絡設定是node1和node2是單元的輸入,所以剛開始這部分需要特殊處理,固定兩個單元,搜索隨後的單元,即還剩下B-2個節點需要搜索。
如上圖所示,從node3開始生成,首先生成兩個需要連接的兩個對象,indexA和indexB; 然後生成兩個op, 分別是sep 5x5和直連id。將操作sep 5x5施加到indexA對應節點上;將操作直連施加到indexB對應節點上,然後通過add的方式融合特徵。
搜索空間複雜度計算:首先分爲Conv Cell和Reduction Cell,由於他們並沒有本質不同,只是所有的操作的stride設置爲2,複雜度也是一樣的。
假定當前是第i個節點,可以選擇來自先前i-1個節點中的兩個節點,並且可選操作有5個。假設只選擇一個節點,那麼複雜度是\(5\times (B-2)!\), 由於要選擇兩個節點,兩個節點的選擇是互相獨立的,所以複雜度計算變爲:\((5\times (B-2)!)^2\) 。而又有Reduction Cell和Conv Cell也是互相獨立的,所以複雜度變爲\((5\times (B-2)!)^4\) ,計算完畢。
4. 實驗結果
主要是在NLP中常用的語料庫Penn Treebank和CV中經典的數據集CIFAR-10上進行了實驗。
4.1 語言模型
在單個GTX 1080Ti上訓練了10個小時,達到了55.8的test perplexity, 下圖是通過ENAS找到的RNN單元。
結果如下:
4.2 圖像分類
數據集:CIFAR10有5w張訓練圖片和1w張測試圖片,使用標準的數據預處理和數據增強方法:如將訓練圖片padding到40x40大小,然後隨機裁剪到32x32,水平隨機反轉。
訓練細節: 共享權重w使用Nesterov momentum來訓練,使用cosine schedule調整lr,lr最大設置爲0.05,最小設置爲0.001,T0=10, Tmul=2。每個子網絡設置運行310個epoch。權重初始化使用He initialization。weight decay設置爲\(10^{-4}\)。
controller的設置細節,policy gradient的權重\(\theta\)使用均勻的從[-0.1,0.1]初始化,使用0.00035的學習率,使用Adam優化器,設置tanh常數爲2.5 temerature 設置爲5.0; 給controller 得到的熵添加0.1的權重。
在macro搜索空間中,通過在skip connection兩層之間添加KL 散度來增加稀疏性, KL散度項對應的權重設置爲0.8.
實驗結果對比如下:
5. 代碼實現
代碼這裏參考NNI中的實現,以macro爲例,ENASLayer實現如下:
class ENASLayer(mutables.MutableScope):
def __init__(self, key, prev_labels, in_filters, out_filters):
super().__init__(key)
self.in_filters = in_filters
self.out_filters = out_filters
self.mutable = mutables.LayerChoice([
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1),
SEConvBranch(in_filters, out_filters, 3, 1, 1, reduction=4)
])
if len(prev_labels) > 0:
self.skipconnect = mutables.InputChoice(
choose_from=prev_labels, n_chosen=None)
else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
def forward(self, prev_layers):
out = self.mutable(prev_layers[-1])
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[:-1])
if connection is not None:
out += connection
return self.batch_norm(out)
其中的mutables是NNI中的一個核心類,可以從LayerChoice所提供的選擇中挑選一個操作,其中最後一個SEConvBranch是筆者自己補充上去的。
- mutable LayerChoice就是從備選選項中選擇其中一個操作
- mutable InputChoice是選擇前幾層節點進行連接。
主幹網絡如下:
class GeneralNetwork(nn.Module):
def __init__(self, num_layers=6, out_filters=12, in_channels=3, num_classes=10,
dropout_rate=0.0):
super().__init__()
self.num_layers = num_layers
self.num_classes = num_classes
self.out_filters = out_filters
self.dropout_rate = dropout_rate
self.stem = nn.Sequential(
nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_filters)
)
pool_distance = self.num_layers // 3
# 進行pool操作是num_layers // 3
self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
self.dropout = nn.Dropout(self.dropout_rate)
self.layers = nn.ModuleList() # convolutional
self.pool_layers = nn.ModuleList() # reduction
labels = []
for layer_id in range(self.num_layers): # 設置12個layer
labels.append("layer_{}".format(layer_id))
if layer_id in self.pool_layers_idx: # 如果使用pool
self.pool_layers.append(FactorizedReduce(
self.out_filters, self.out_filters))
self.layers.append( # 相當於Node節點
ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes)
def forward(self, x):
bs = x.size(0)
cur = self.stem(x)
layers = [cur]
for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers)
layers.append(cur)
if layer_id in self.pool_layers_idx:
# 如果輪到了池化層
for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(
layer_id)](layer)
cur = layers[-1]
cur = self.gap(cur).view(bs, -1)
cur = self.dropout(cur)
logits = self.dense(cur)
return logits
需要注意有幾個點:
- self.stem是第一個node,手動設置的。
- 池化是強制設置的,在某些層規定進行下采樣。
搜索過程調用了NNI提供的API:
model = GeneralNetwork()
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency,
mutator=mutator)
mutator是NNI提供的一個類,就是上述提到的controller,這裏具體調用的是EnasMutator。
def _sample_layer_choice(self, mutable):
# 選擇 某個層 只需要選一個就可以了
self._lstm_next_step() # 讓_inputs在lstm中進行一次前向傳播
logit = self.soft(self._h[-1]) # linear 從隱藏層embedd得到可選的層的邏輯評分
if self.temperature is not None:
logit /= self.temperature # 一個常量 貌似是RL中的trick
if self.tanh_constant is not None:
# tanh_constant * tanh(logits) 用tanh再激活一次(可選)
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
# 對卷積層進行了偏好處理,如果是卷積層,那就在對應的值加上一個0.25,增大被選中的概率
# softmax, view(-1),
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
# 依據概率來選下角標,如果數量不爲1,選擇的多箇中沒有重複的
# eg: [100,1,1] 最有可能選擇100對應的下標0
log_prob = self.cross_entropy_loss(logit, branch_id) # 交叉熵損失函數 - 判斷logit和branchid分佈是否相似程度
self.sample_log_prob += self.entropy_reduction(log_prob) # 求和或者求平均
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type ??
self.sample_entropy += self.entropy_reduction(entropy) # 樣本熵?
self._inputs = self.embedding(branch_id) # 得到對應id的embedding, 從選擇空間 - 映射到 - 隱空間
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) # 將選擇變成one_hot向量
這部分是EnasMutator中一個核心函數,實現的是REINFORCE算法。
if self.entropy_weight: # 交叉熵權重
reward += self.entropy_weight * self.mutator.sample_entropy.item() # 得到樣本熵
6. 總結
ENAS核心就是提出了一個超網,每次從超網中採樣一個小的網絡進行訓練。所有的子網絡都是共享超網中的一套參數,這樣每次訓練就不是從頭開始訓練,而是進行了遷移學習,加快了訓練速度。
有註釋代碼鏈接如下:https://github.com/pprp/SimpleCVReproduction/tree/master/nni