神經網絡架構搜索——可微分搜索(SGAS)
KAUST&Intel發表在CVPR 2020上的NAS工作,針對現有DARTS框架在搜索階段具有高驗證集準確率的架構可能在評估階段表現不好的問題,提出了分解神經網絡架構搜索過程爲一系列子問題,SGAS使用貪婪策略選擇並剪枝候選操作的技術,在搜索CNN和GCN網絡架構均達到了SOTA。
- Paper: SGAS: Sequential Greedy Architecture Search
- Code: https://github.com/lightaime/sgas
動機
NAS技術都有一個通病:在搜索過程中驗證精度較高,但是在實際測試精度卻沒有那麼高。傳統的基於梯度搜索的DARTS技術,是根據block構建更大的超網,由於搜索的過程中驗證不充分,最終eval和test精度會出現鴻溝。從下圖的Kendall係數來看,DARTS搜出的網絡精度排名和實際訓練完成的精度排名偏差還是比較大。
方法
整體思路
本文使用與DARTS相同的搜索空間,SGAS搜索過程簡單易懂,如下圖所示。類似DARTS搜索過程爲每條邊指定參數α,超網訓練時通過文中判定規則逐漸確定每條邊的具體操作,搜索結束後即可得到最終模型。
爲了保證在貪心搜索的過程中能儘量保證搜索的全局最優性,進而引入了三個指標和兩個評估準則。
三個指標
邊的重要性
非零操作參數對應的softmax值求和,作爲邊的重要性衡量指標。
alphas = []
for i in range(4):
for n in range(2 + i):
alphas.append(Variable(1e-3 * torch.randn(8)))
# alphas經過訓練後
mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach() # mat爲14*8維度的二維列表,softmax歸一化。
EI = torch.sum(mat[:, 1:], dim=-1) # EI爲14個數的一維列表,去掉none後的7個ops對應alpha值相加
選擇的準確性
計算操作分佈的標準化熵,熵越小確定性越高;熵越高確定性越小。
import torch.distributions.categorical as cate
probs = mat[:, 1:] / EI[:, None]
entropy = cate.Categorical(probs=probs).entropy() / math.log(probs.size()[1])
SC = 1-entropy
選擇的穩定性
將歷史信息納入操作分佈評估,使用直方圖交叉核計算平均選擇穩定性。直方圖交叉核的原理詳見(https://blog.csdn.net/hong__fang/article/details/50550656)。
def histogram_intersection(a, b):
c = np.minimum(a.cpu().numpy(),b.cpu().numpy())
c = torch.from_numpy(c).cuda()
sums = c.sum(dim=1)
return sums
def histogram_average(history, probs):
histogram_inter = torch.zeros(probs.shape[0], dtype=torch.float).cuda()
if not history:
return histogram_inter
for hist in history:
histogram_inter += utils.histogram_intersection(hist, probs)
histogram_inter /= len(history)
return histogram_inter
probs_history = []
probs_history.append(probs)
if (len(probs_history) > args.history_size):
probs_history.pop(0)
histogram_inter = histogram_average(probs_history, probs)
SS = histogram_inter
兩種評估準則
評估準則1:
選擇具有高邊緣重要性和高選擇確定性的操作
def normalize(v):
min_v = torch.min(v)
range_v = torch.max(v) - min_v
if range_v > 0:
normalized_v = (v - min_v) / range_v
else:
normalized_v = torch.zeros(v.size()).cuda()
return normalized_v
score = utils.normalize(EI) * utils.normalize(SC)
評估準則2:
在評估準則1的基礎上,加入考慮選擇穩定性
score = utils.normalize(EI) * utils.normalize(SC) * utils.normalize(SS)
實驗結果
CIFAR-10(CNN)
ImageNet(CNN)
ModelNet40(GCN)
PPI(GCN)
參考
[1] Li, Guohao et al. ,SGAS: Sequential Greedy Architecture Search
[2] https://zhuanlan.zhihu.com/p/134294068
[3] 直方圖交叉核 https://blog.csdn.net/hong__fang/article/details/50550656