神經網絡架構搜索——可微分搜索(PC-DARTS)
華爲發表在ICLR 2020上的NAS工作,針對現有DARTS模型訓練時需要 Large memory and computing 問題,提出了 Partial Channel Connection 和 Edge Normalization 的技術,在搜索過程中更快更好。
- aper: PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search
- Code: https://github.com/yuhuixu1993/PC-DARTS
動機
接着上面的P-DARTS來看,儘管上面可以在17 cells情況下單卡完成搜索,但妥協犧牲的是operation的數量,這明顯不是個優秀的方案,故此文 Partially-Connected DARTS,致力於大規模節省計算量和memory,從而進行快速且大batchsize的搜索。
貢獻點
-
設計了基於channel的sampling機制,故每次只有小部分1/K channel的node來進行operation search,減少了(K-1)/K 的memory,故batchsize可增大爲K倍。
-
爲了解決上述channel採樣導致的不穩定性,提出了 邊緣正規化(edge normalization),在搜索時通過學習edge-level超參來減少不確定性。
方法
部分通道連接(Partial Channel Connection)
如上圖的上半部分,在所有的通道數K裏隨機採樣 1/K 出來,進行 operation search,然後operation 混合後的結果與剩下的 (K-1)/K 通道數進行 concat,公式表示如下:
上述的“部分通道連接”操作會帶來一些正副作用:
- 正作用:能減少operations選擇時的biases,弱化無參的子操作(Pooling, Skip-Connect)的作用。文中3.3節有這麼一句話:當proxy dataset非常難時(即ImageNet),往往一開始都會累積很大權重在weight-free operation,故制約了其在ImageNet上直接搜索的性能。
- 副作用:由於網絡架構在不同iterations優化是基於隨機採樣的channels,故最優的edge連通性將會不穩定。
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
self.mp = nn.MaxPool2d(2,2)
for primitive in PRIMITIVES:
op = OPS[primitive](C //4, stride, False)
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C //4, affine=False))
self._ops.append(op)
def forward(self, x, weights):
#channel proportion k=4(實驗證明1/4性能最佳)
dim_2 = x.shape[1]
xtemp = x[ : , : dim_2//4, :, :] # channel 0到1/4的輸入
xtemp2 = x[ : , dim_2//4:, :, :] # channel 1/4到1的輸入
temp1 = sum(w * op(xtemp) for w, op in zip(weights, self._ops)) # 僅1/4數據參與ops運算
#reduction cell 需要在concat之前添加pooling操作
if temp1.shape[2] == x.shape[2]:
ans = torch.cat([temp1,xtemp2],dim=1)
else:
ans = torch.cat([temp1,self.mp(xtemp2)], dim=1)
ans = channel_shuffle(ans,4) # 一個cell完成後對channel進行隨機打散,爲下個cell做採樣準備
#ans = torch.cat([ans[ : , dim_2//4:, :, :],ans[ : , : dim_2//4, :, :]],dim=1)
#except channe shuffle, channel shift also works
return ans
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape [batchsize, num_channels, height, width]
# -> [batchsize, groups,channels_per_group, height, width]
x = x.view(batchsize, groups,
channels_per_group, height, width)
# 打亂channel的操作(藉助transpose後數據塊的stride發生變化,然後將其連續化)
# 參考:https://www.cnblogs.com/aoru45/p/10974508.html
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
邊緣正規化(Edge Normalization)
爲了克服部分通道連接這個副作用,提出邊緣正規化(見上圖的下半部分),即把多個PC後的node輸入softmax權值疊加,類attention機制
由於edge 超參 在訓練階段是共享的,故學習到的網絡更少依賴於不同iterations間的採樣到的channels,使得網絡搜索過程更穩定。當網絡搜索完畢,node間的operation選擇由operation-level和edge-level的參數相乘後共同決定。
weights_normal = [F.softmax(alpha, dim=-1) for alpha in alpha_normal]
weights_reduce = [F.softmax(alpha, dim=-1) for alpha in alpha_reduce]
weights_edge_normal = [F.softmax(beta, dim=0) for beta in beta_normal]
weights_edge_reduce = [F.softmax(beta, dim=0) for beta in beta_reduce]
def parse(alpha, beta, k):
...
for edges, w in zip(alpha, beta):
edge_max, primitive_indices = torch.topk((w.view(-1, 1) * edges)[:, :-1], 1) # ignore 'none'
...
實驗結果
CIFAR-10
ImageNet
消融實驗
參考
[1] Yuhui Xu et al. ,PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search
[2] https://zhuanlan.zhihu.com/p/73740783