隨機權值平均(Stochastic Weight Averaging,SWA)
隨機權值平均只需快速集合集成的一小部分算力,就可以接近其表現。SWA 可以用在任意架構和數據集上,都會有不錯的表現。根據論文中的實驗,SWA 可以得到我之前提到過的更寬的極小值。在經典認知下,SWA 不算集成,因爲在訓練的最終階段你只得到一個模型,但它的表現超過了快照集成,接近 FGE(多個模型取平均)。
左圖:W1、W2、W3分別代表3個獨立訓練的網絡,Wswa爲其平均值。中圖:WSWA 在測試集上的表現超越了SGD。右圖:WSWA 在訓練時的損失比SGD要高。
結合 WSWA 在測試集上優於 SGD 的表現,這意味着儘管 WSWA 訓練時的損失較高,它的泛化性更好。
SWA 的直覺來自以下由經驗得到的觀察:每個學習率週期得到的局部極小值傾向於堆積在損失平面的低損失值區域的邊緣(上圖左側的圖形中,褐色區域誤差較低,點W1、W2、3分別表示3個獨立訓練的網絡,位於褐色區域的邊緣)。對這些點取平均值,可能得到一個寬闊的泛化解,其損失更低(上圖左側圖形中的 WSWA)。
下面是 SWA 的工作原理。它只保存兩個模型,而不是許多模型的集成:
- 第一個模型保存模型權值的平均值(WSWA)。在訓練結束後,它將是用於預測的最終模型。
- 第二個模型(W)將穿過權值空間,基於週期性學習率規劃探索權重空間。
SWA權重更新公式
在每個學習率週期的末尾,第二個模型的當前權重將用來更新第一個模型的權重(公式如上)。因此,在訓練階段,只需訓練一個模型,並在內存中儲存兩個模型。預測時只需要平均模型,基於其進行預測將比之前描述的集成快很多,因爲在那種集成中,你需要使用多個模型進行預測,最後再進行平均。
方法實現:
論文的作者自己提供了一份 PyTorch 的實現 :
此外,基於 fast.ai 庫的 SWA 可見 :
Add Stochastic Weight Averaging by wdhorton · Pull Request #276 · fastai/fastaigithub.com