神經網絡架構搜索——可微分搜索(PC-DARTS)

華爲發表在ICLR 2020上的NAS工作,針對現有DARTS模型訓練時需要 Large memory and computing 問題,提出了 Partial Channel ConnectionEdge Normalization 的技術,在搜索過程中更快更好

  • aper: PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search
  • Code: https://github.com/yuhuixu1993/PC-DARTS

動機

接着上面的P-DARTS來看,儘管上面可以在17 cells情況下單卡完成搜索,但妥協犧牲的是operation的數量,這明顯不是個優秀的方案,故此文 Partially-Connected DARTS,致力於大規模節省計算量和memory,從而進行快速且大batchsize的搜索。

貢獻點

  • 設計了基於channel的sampling機制,故每次只有小部分1/K channel的node來進行operation search,減少了(K-1)/K 的memory,故batchsize可增大爲K倍。

  • 爲了解決上述channel採樣導致的不穩定性,提出了 邊緣正規化(edge normalization),在搜索時通過學習edge-level超參來減少不確定性。

方法

PC-DARTS架構

部分通道連接(Partial Channel Connection)

如上圖的上半部分,在所有的通道數K裏隨機採樣 1/K 出來,進行 operation search,然後operation 混合後的結果與剩下的 (K-1)/K 通道數進行 concat,公式表示如下:

fi,jPC(xi;Si,j)=oOexp{αi,jo}oOexp{αi,jo}o(Si,jxi)+(1Si,j)xi f_{i, j}^{\mathrm{PC}}\left(\mathbf{x}_{i} ; \mathbf{S}_{i, j}\right)=\sum_{o \in \mathcal{O}} \frac{\exp \left\{\alpha_{i, j}^{o}\right\}}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left\{\alpha_{i, j}^{o^{\prime}}\right\}} \cdot o\left(\mathbf{S}_{i, j} * \mathbf{x}_{i}\right)+\left(1-\mathbf{S}_{i, j}\right) * \mathbf{x}_{i}

上述的“部分通道連接”操作會帶來一些正副作用:

  • 正作用:能減少operations選擇時的biases,弱化無參的子操作(Pooling, Skip-Connect)的作用。文中3.3節有這麼一句話:當proxy dataset非常難時(即ImageNet),往往一開始都會累積很大權重在weight-free operation,故制約了其在ImageNet上直接搜索的性能。
  • 副作用:由於網絡架構在不同iterations優化是基於隨機採樣的channels,故最優的edge連通性將會不穩定。
class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.mp = nn.MaxPool2d(2,2)

    for primitive in PRIMITIVES:
      op = OPS[primitive](C //4, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.BatchNorm2d(C //4, affine=False))
      self._ops.append(op)

  def forward(self, x, weights):
    #channel proportion k=4(實驗證明1/4性能最佳)
    dim_2 = x.shape[1]
    xtemp = x[ : , :  dim_2//4, :, :] # channel 0到1/4的輸入
    xtemp2 = x[ : ,  dim_2//4:, :, :] # channel 1/4到1的輸入
    temp1 = sum(w * op(xtemp) for w, op in zip(weights, self._ops)) # 僅1/4數據參與ops運算
    #reduction cell 需要在concat之前添加pooling操作
    if temp1.shape[2] == x.shape[2]:
      ans = torch.cat([temp1,xtemp2],dim=1)
    else:
      ans = torch.cat([temp1,self.mp(xtemp2)], dim=1)
    ans = channel_shuffle(ans,4) # 一個cell完成後對channel進行隨機打散,爲下個cell做採樣準備
    #ans = torch.cat([ans[ : ,  dim_2//4:, :, :],ans[ : , :  dim_2//4, :, :]],dim=1)
    #except channe shuffle, channel shift also works
    return ans

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape [batchsize, num_channels, height, width] 
    # -> [batchsize, groups,channels_per_group, height, width]
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)
		# 打亂channel的操作(藉助transpose後數據塊的stride發生變化,然後將其連續化)
    # 參考:https://www.cnblogs.com/aoru45/p/10974508.html
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

邊緣正規化(Edge Normalization)

爲了克服部分通道連接這個副作用,提出邊緣正規化(見上圖的下半部分),即把多個PC後的node輸入softmax權值疊加,類attention機制

xjPC=oOexp{αi,jo}oOexp{αi,jo}o(Si,jxi)+(1Si,j)xi \mathbf{x}_{j}^{\mathrm{PC}}=\sum_{o \in \mathcal{O}} \frac{\exp \left\{\alpha_{i, j}^{o}\right\}}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left\{\alpha_{i, j}^{o^{\prime}}\right\}} \cdot o\left(\mathbf{S}_{i, j} * \mathbf{x}_{i}\right)+\left(1-\mathbf{S}_{i, j}\right) * \mathbf{x}_{i}

xjPC=i<jexp{βi,j}i<jexp{βi,j}fi,j(xi) \mathbf{x}_{j}^{\mathrm{PC}}=\sum_{i < j} \frac{\exp \left\{\beta_{i, j}\right\}}{\sum_{i^{\prime} < j} \exp \left\{\beta_{i^{\prime}, j}\right\}} \cdot f_{i, j}\left(\mathbf{x}_{i}\right)

由於edge 超參 βi,j\beta_{i, j} 在訓練階段是共享的,故學習到的網絡更少依賴於不同iterations間的採樣到的channels,使得網絡搜索過程更穩定。當網絡搜索完畢,node間的operation選擇由operation-level和edge-level的參數相乘後共同決定。

weights_normal = [F.softmax(alpha, dim=-1) for alpha in alpha_normal]
weights_reduce = [F.softmax(alpha, dim=-1) for alpha in alpha_reduce]
weights_edge_normal = [F.softmax(beta, dim=0) for beta in beta_normal]
weights_edge_reduce = [F.softmax(beta, dim=0) for beta in beta_reduce]


def parse(alpha, beta, k):
		...  
    for edges, w in zip(alpha, beta):
        edge_max, primitive_indices = torch.topk((w.view(-1, 1) * edges)[:, :-1], 1) # ignore 'none'
    ...

實驗結果

CIFAR-10

CIFAR-10結果

ImageNet

ImageNet結果

消融實驗

消融實驗

參考

[1] Yuhui Xu et al. ,PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search

[2] https://zhuanlan.zhihu.com/p/73740783


更多內容關注微信公衆號【AI異構】

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章