1. INFO
Title: SMASH: One-Shot Model Architecture Search through HyperNetworks
Author: Andrew Brock, Theodore Lim, & J.M. Ritchie
Link: https://arxiv.org/pdf/1708.05344.pdf
Date: ICLR 2018 Poster
Code:https://github.com/ajbrock/SMASH
2. Motivation
高性能的深度神經網絡需要大量工程設計,而網絡的細節比如深度、連接方式等往往遵從了一些網絡設計規則,比如ResNet,Inception,FractalNets,DenseNets等。但即便在這些規則的指導下,也需要設計一系列實驗來決定網絡具體的配置,如深度,通道數等。
SMASH就是爲了越過這個昂貴的訓練候選模型的過程而提出來的,通過使用輔助網絡生成權重來解決這個問題。
3. Contribution
SMASH通過引入HyperNet來根據候選網絡的結構動態生成權重。
雖然通過輔助網絡生成候選網絡權重的方式得到的驗證集精度不高,但是不同候選網絡結果的表現和從頭訓練候選網絡的表現具有相對一致性,所以可以作爲挑選候選網絡的指導。
搜索網絡的設計使用了基於內存讀寫的靈活機制(memory read-write),從而定義了一個包括ResNet、DenseNet、FractalNets的搜索空間。
SMASH也有侷限性:這個方法不能自己發現全新的結構,因爲他只能動態生成模型參數的特定子集。
4. Method
SMASH的僞代碼如下:
- 首先設計所有候選網絡的搜索空間。
- 初始化超網權重H
- 對於一個batch的數據x,一個隨機的候選網絡c,和根據候選網絡架構的生成的權重W=H(c)
- 得到loss,進行反向傳播,並更新H
- 以上完成以後,就得到了一個訓練完成的HyperNet。
- 採樣隨機的候選網絡c,在驗證集上得到驗證loss,找到loss最小的候選網絡。
- 對得到的候選網絡進行從頭訓練,得到在驗證機上的精度。
對於以上過程,有幾個點需要搞清楚:
- 候選網絡搜索空間是如何構建的?
- 如何根據候選網絡架構得到權重?
4.1 候選網絡搜索空間
對於第一個問題,從內存讀寫的角度來考慮採樣複雜、帶分支的拓撲,並將該拓撲編碼爲二進制特徵向量。
普通的網絡是從前向傳播-反向傳播信號的角度來設計的,這篇文章從內存的讀寫角度來看待網絡結構,被稱爲Memory-Bank representation。
從這個角度,每個層就代表一個從內存中一部分讀取數據的操作,比如左邊的是resnet示意圖,從內存中讀取數據x,經過conv處理,得到conv(x),然後寫到內存中x+conv(x)結果。中間的圖展示的是DenseNet,回顧DenseNet,在每個block內部中,每個節點都和之前的所有節點相連接。
那麼在Memory-Bank的表示方法中,以3個節點爲例:
- 從第一塊內存讀取數據x
- 通過第一個conv1,得到conv1(x),並寫回第二塊內存。
- 從第二塊內存讀取conv1(x),經過第二個conv2,得到conv2(conv1(x))
- 從第一塊內存讀取x,經過第二個conv2,得到conv2(x)
- 兩者concate到一起寫回第三塊內存concat(conv2(conv1(x)), conv2(x))
SMASH採用的網絡結構和以上三種網絡類似,由多個block組成,其中降採樣部分使用的是1x1卷積,分辨率減半。其中全連接層和1x1卷積權重是通過學習得到的,不是通過HyperNet生成的。
下圖展示的是一個op的結構,一個1x1卷積在memory-bank上的操作,後邊跟着最多兩條卷積路徑。左側第一個灰色梯形代表1x1conv,用於調整channel個數,然後不同的分支代表選擇不同類型的卷積。
在採樣網絡的過程中,每個block內部的memory bank的個數是隨機的,每個memory-bank的channel個數也是隨機的。block中的層隨機選擇讀寫模型以及相對應op。
當讀入read了多個memory-bank, 在channel維度進行concat,寫入write是將每個memory-bank中的結果相加。
實驗中,op僅允許讀取所屬block的memory-bank。op有1x1卷積、若干常規卷積、非線性激活函數。
4.2 爲給定網絡生成權重
SMASH中提出的Dynamic Hypernet是基於網絡結構c得到對應的權重W。
優化的目標是學習一個映射W=H(c)能夠對任意一個架構c來說,H(c)能夠儘可能接近最優的W。
HyperNet是一個全卷積組成的網絡,所以其輸出的張量W隨着輸入網絡結構c的變化而變化,其標準的形式是4D的 BCHW,其中B=1。
舉個例子,如果op從第1,2,4個memory-bank中讀取,然後寫回第2,4個memory-bank。那麼第1,2,4個通道對應的值被賦值爲1(代表read模式,如上圖所示),第6(2+4),8(4+4)個通道被賦值爲1(代表write模式)。通過以上方式得到了對網絡結構的編碼。
通過以上例子,終於搞清楚瞭如何從memory-bank的角度來表徵網絡結構,剩下生成W權重的工作採用的是MLP來完成的。
5. Experiment
實驗部分需要驗證聲明Claim:
- 通過HyperNet生成權重W的這種方式具有較好的排序一致性。
- 證明SMASH方法的有效性,架構表徵c在整個算法中是否真正起到了作用。
- 驗證算法的可擴展性,使用遷移學習的方法來證明。
- 和其他SOTA方法進行比較
5.1 測試SMASH的相關性
橫座標是HyperNet生成權重得到的驗證集錯誤率,縱座標代表模型真實訓練得到的驗證集錯誤率,紅色線代表使用最小二乘法得到的結果。
根據這根線就得到了一致性?相當於使用目測的方法得到結論,感覺可以用統計學的方法計算出置信度,或者來計算一下kendall tau或者Person係數能更好的反映結果。
作者在說明這個結果的時候也很有意思:這個實驗結果只能表明在當前設置的實驗場景下是滿足相關性的,但既不能保證相關性的通用性,也不能保證相關性成立的條件。由於實驗代價過高,無法進一步得到結果。
所以需要第二個實驗來輔助,設計一個代理模型,其中的模型權重相比於正常的SMASH要小很多,減少了實驗代價。
而第三個實驗,提出了更高預算的網絡(因爲參數量不夠的情況下,模型表現會很差,爲了和SOTA比較,所以需要提高預算),但是實驗發現其SMASH分數和真實性能並不相關。發現了SMASH中存在的缺陷。
5.2 使用Proxy進一步證明SMASH有效性
左圖展示的是低預算情況下SMASH,並不存在明確的相關性。這說明在模型容量比較小的情況下,HyperNet很難學到良好的權值,導致SMASH得分和真實性能之前沒有明確的相關性。
通過破壞網絡架構表示c,發現對於給定的網絡架構,使用正確的網絡架構表示所能生成的SMASH驗證性能是最高的,證明了網絡架構表示的有效性。
5.3 遷移學習
作者發現這種算法搜出來的模型具有很好的可遷移性:CIFAR-100上搜出來的模型在STL-10上表現優於直接在STL-10上搜的模型。原因可能是CIFAR-100擁有更多的訓練樣例,能夠使HyperNet更好地選擇模型架構。
5.4 SOTA比較
可以看到,SMASH的結果並沒有達到SOTA,處於中等水平,但是總體上優於其他RL-based和進化算法。確實是比不上NASNet,但是SMASH並沒有像NASNet一樣進行了超參數網格搜索。
6. Revisiting
這篇文章讀下來花費了好長時間,總結一下這篇奇怪的文章:
- 提出了HyperNet生成網絡權重的想法,輸入是網絡架構表示,輸出是網絡的權重。並提出一個前提:HyperNet生成權重後的網絡和真實訓練的網絡的性能具有相關性。
- 提出了一種從memory-bank角度來看待網絡的方法,相比普通的前向反向傳播角度,一開始比較難以接受。
- 從memory-bank角度提出了比較複雜的網絡的編碼方式,用於表達網絡架構c。
- 使用MLP實現HyperNet,輸出得到網絡的權重。
想法:從個人角度出發,這篇文章想法很奇特,直接生成網絡的權重。個人也是很佩服作者的工程能力,生成網絡權重,並讓整個模型work需要多大的工程量不必多言。作者也在文章中透露了各種不work的方式,不斷地調整,比如使用weightNorm等來解決收斂問題。調整網絡的容量來想辦法儘可能提高網絡的性能表現,最終能在cifar10上得到96%的top1已經很不容易了。
這個也給了我們啓發,即便是並沒有達到SOTA,這個創新的想法、紮實的工作也可以被頂會接收。
最後作者開源了源碼,並且給源碼附上了詳細的註釋。
7. Reference
https://arxiv.org/pdf/1608.06993.pdf