SURE:增強不確定性估計的組合拳,快加入到你的訓練指南吧 | CVPR 2024

論文重新審視了深度神經網絡中的不確定性估計技術,並整合了一套技術以增強其可靠性。論文的研究表明,多種技術(包括模型正則化、分類器改造和優化策略)的綜合應用顯着提高了圖像分類任務中不確定性預測的準確性

來源:曉飛的算法工程筆記 公衆號

論文: SURE: SUrvey REcipes for building reliable and robust deep networks

Introduction


  深度神經網絡 (DNNs) 已成爲結構化數據預測任務中強大且適應性高的工具,但準確評估其預測的可靠性仍然是一個巨大的挑戰。在醫療診斷、機器人、自動駕駛和地球觀測系統等關鍵安全領域,過度自信的預測的決策可能會導致嚴重的後果。因此,確保基於DNN的人工智能系統的魯棒性至關重要。

  解決深度學習中的過度自信問題一直是重大研究工作的焦點,但目前很多方法的一個關鍵限制是測試場景有限,通常僅限於單個預定義任務(例如故障預測或分佈外檢測(OOD))的基準數據集。這些方法在涉及更復雜的現實情況時(如數據損壞、標籤噪聲或長尾類分佈等),其有效性仍很大程度上尚未得到充分探索。而且通過實驗表明,沒有一種方法能夠在不同的場景中表現一致。爲此,論文提出了一個有效解決所有這些挑戰的統一模型。

  在論文追求增強不確定性估計的過程中,論文首先檢查幾種現有方法的綜合影響,從而發現一種可以顯着改進的綜合方法。根據這些方法在模型訓練過程中的功能對進行分類:

  • 正則化和分類器:利用RegMixup正則化、正確性排名損失 (CRL) 和餘弦相似性分類器 (CSC) 等技術,這有助於增加具有挑戰性的樣本的熵。
  • 優化策略:按照FMFP的建議結合了銳度感知最小化 (SAM) 和隨機權重平均 (SWA),確保模型能夠收斂到更平坦的最小值。

  這些不同技術的協同整合最終形成了論文的新穎方法SURE,該方法利用了每個單獨組件的優勢,產生了更加穩健和可靠的模型。

  在評估SURE時,論文首先關注錯誤預測(failure prediction),這是評估不確定性估計的關鍵任務。結果表明,SURE始終優於部署單獨技術的模型。這種卓越的性能在CIFAR10CIFAR-100Tiny-ImageNet等各種數據集以及ResNetVGGDenseNetWideResNetDeiT等各種模型架構中都很明顯。值得注意的是,SURE甚至超越了OpenMix,這是一種利用額外OOD數據的方法。通過將SURE直接應用到現實場景中,無需或只進行很少的特定於任務的調整,進一步見證了在爲模型帶來魯棒性方面的有效性。具體來說,現實世界的挑戰包括CIFAR10-C中的數據損壞、Animal-10NFood-101N中的標籤噪聲以及CIFARLT中的類分佈傾斜。在這些背景下,SURE取得的結果要麼優於最新的方法,要麼與最新的方法相當。SUREFood-101N上達到了 88.0% 的令人印象深刻的準確率,顯着超過了之前最先進的方法Jigsaw-ViT,該方法通過使用額外的預訓練數據達到了 86.7% 的準確率,這證明了SURE在處理複雜的現實數據挑戰方面的卓越能力。

  本文的主要貢獻總結如下:

  • 實驗證明現有方法在應對各種現實挑戰時並不總能表現出色,需要更可靠、更穩健的方法來處理現實世界數據的複雜性。
  • 提出用於魯棒的不確定性估計的新穎方法SURE,結合模型正則化、分類器和優化策略等多種技術所實現的協同效應。在SURE方法下訓練的模型在故障預測方面始終比在各種數據集和模型架構中部署單獨技術的模型取得更好的性能。
  • 直接應用於現實場景時,SURE始終表現出至少與最先進的方法相當的性能。

Methods


  如圖 2 所示,SURE旨在通過兩個方面訓練可靠且魯棒的DNN:i)增加難樣本的熵; ii) 在優化過程中強制尋找平坦極值(flat minima)。

  定義 \(\{(\mathbf{x}_{i},\mathbf{y}_{i})\}_{i=1}^{N}\) 表示數據集,其中 \(\mathbf{x}_{i}\) 是輸入圖像,\(\mathbf{y}_{i}\) 是其標籤,\(N\) 是樣本數。

SURE中增加難樣本熵的方法由三個部分組成:

  • 增加RegMixup正則化 \(\mathcal{L}_{mix}\),通過數據增強添加難樣本。
  • 增加正確性排名損失 \(\mathcal{L}_{crl}\),通過將實例的置信度與正確預測次數比例進行排序對齊來正則化類概率。
  • 在分類的交叉熵損失 \({\mathcal{L}}_{ce}\)使用餘弦相似度分類器(CSC)的結果作爲輸入,可以更好地表達難樣本。

  此外,爲了平坦極值,在優化過程中使用銳度感知最小化 (SAM) 和隨機權重平均 (SWA)。

Increasing entropy for hard samples

  • Total loss

  如上所述,SURE的目標函數由三部分組成,表示爲:

\[\mathcal{L}_{total}=\mathcal{L}_{ce}+\lambda_{mix}\mathcal{L}_{mix}+\lambda_{crl}\mathcal{L}_{crl} \quad\quad (1) \]

  • RegMixup regularization

Mixup是一種廣泛用於圖像分類的數據增強方法。

  給定兩個輸入目標對 \((\mathbf{x}_{i},\mathbf{y}_{i})\)\((\mathbf{x}_{j},\mathbf{y}_{j})\),通過線性插值來獲得增強樣本 \((\tilde{\mathbf{x}}_{i}, {\tilde{\mathbf{y}}}_{i})\)

\[\tilde{{\bf x}}_{i}=m{\bf x}_{i}+(1-m){\bf x}_{j},\quad\tilde{{\bf y}}_{i}=m{\bf y}_{i}+(1-m){\bf y}_{j} \quad\quad (2) \]

  其中 \(m\) 表示混合係數,遵循Beta分佈:

\[m\sim\mathrm{Beta}(\beta,\beta),~~~\beta\in(0,\infty) \quad\quad (3) \]

RegMixup正則化 \(\mathcal{L}_{mix}\) 計算增強樣本的損失值:

\[\mathcal{L}_{mix}(\tilde{\bf x}_{i},\tilde{\bf y}_{i})=\mathcal{L}_{ce}(\tilde{\bf x}_{i},\tilde{\bf y}_{i}) \quad\quad (4) \]

  設置 \(\beta=10\),確保兩個樣本高度混合。

  與RegMixup類似,將 \(\mathcal{L}_{mix}\) 作爲附加正則化器,與 \((\mathbf{x}_{i},\mathbf{y}_{i})\) 上的原始交叉熵損失 \(\mathcal{L}_{ce}\) 一起使用。 較高的 \(\beta\) 值會導致樣本嚴重混合,促使模型在大量的插值樣本上表現出高熵,增加訓練的挑戰性。

  • Correctness ranking loss

  正確性排名損失鼓勵DNN將模型的置信度與訓練期間收集的正確預測比例信息保持一致(即經常預測正確的圖像,其置信度也應該高於不經常預測正確的圖像)。

  對於兩個輸入圖像 \(\mathbf{x}_{i}\)\(\mathbf{x}_{j}\)\(\mathcal{L}_{crl}\) 的定義爲:

\[{\mathcal{L}}_{crl}(\mathbf{x}_{i},\mathbf{x}_{j})=\operatorname*{max}(0,|c_{i}-c_{j}|-\operatorname{sign}(c_{i}-c_{j})(\mathbf{s}_{i}-\mathbf{s}_{j})) \quad\quad (5) \]

  其中 \(c_{i}\)\(c_{j}\) 表示訓練期間 \(\mathbf{x}_{i}\)\(\mathbf{x}_{j}\) 被正確預測的比例,\(\mathbf{s}_{i}\)\(\mathbf{s}_{j}\) 表示 \(\mathbf{x}_{i}\)\(\mathbf{x}_{j}\) 的置信度得分,即softmax得分,sign表示符號函數。

\(\mathcal{L}_{crl}\) 旨在將置信度得分與正確性統計數據對齊,難樣本在訓練過程中不太可能被正確預測,因此鼓勵其具有較低的置信度,從而具有較高的熵來進行反向更新。

  • Cosine Similarity Classifier (CSC)

CSC通過簡單地用餘弦分類器替換最後一個線性層,在少樣本分類中有不錯效果。簡單而言就是每個類學習一個原型向量,將其與圖像的特徵網絡輸出進行餘弦相似計算,將結果作爲預測分數。

  對於圖像 \(\mathbf{x}_{i}\) ,分類向量中對應 \(k\) 類的單元表示爲 \(\mathbf{s}_{i}^{k}\) ,其定義如下:

\[\mathrm{s}_{i}^{k}=\tau\cdot\mathrm{cos}(f_{\theta}(\mathbf{x}_{i}),w^{k})=\tau\cdot\frac{f_{\theta}(\mathbf{x}_{i})}{||f_{\theta}(\mathbf{x}_{i})||_{2}}\cdot\frac{w^{k}}{||w^{k}||_{2}}, \quad\quad (6) \]

  其中 \(\tau\) 是溫度超參數,\(f_{\theta}\)\(\theta\) 參數化的DNN網絡,用於提取輸入圖像的特徵,\(w^{k}\) 代表第 \(k\) 類的原型向量。

CSC鼓勵分類器關注從輸入圖像提取的特徵向量與類原型向量之間的方向對齊,這使得它在概念上不同於傳統的線性分類器。傳統的線性分類器中關注點積得出的幅值(用於進行softmax),而CSC僅關注其方向是否一致。CSC的一個主要好處是能夠更好地處理難樣品,將難樣本視爲與多個類原型向量在餘弦角度相等,從而比使用點積的傳統線性分類器提供更有效的可解釋性和潛在更高的熵。

Flat minima-enforced optimization

  論文聯合採用銳度感知最小化(SAM)和隨機權重平均(SWA)來增強平面最小值。

  • Sharpness-Aware Minimization (SAM)

  由於參數量巨大,深度模型存在較多的局部極值,而優化過程就是在尋找其中一個極值。一般認爲,平坦的極值比尖銳的極值的泛化能力更強。爲此,SAM通過尋找鄰域平坦的參數來增強模型泛化能力,從而使DNN具有一致的小損失,避免陷入尖銳的局部極值。

  對於論文的目標函數 \({\mathcal{L}}_{total}\)DNN參數 \({\boldsymbol{\theta}}\)SAM優化器尋求滿足以下公式的 \(\theta\)

\[\underset{\theta}{\mathrm{min}}\underset{||\epsilon||_2\leq\rho}{\mathrm{max}} \mathcal{L}_{total}(\theta+\epsilon) \quad\quad(7) \]

  其中 \(\epsilon\) 是擾動向量,\(\rho\) 是論文尋求最小化損失銳度的鄰域大小。

SAM算法在 \(\ell_2\) 範數小於 \(\rho\) 的範圍內尋找使損失最大化的擾動向量 \(\epsilon\)(此過程需要基於 \(\theta\) 產生的梯度進行計算),然後基於 \(\theta + \epsilon\) 產生的新梯度反向更新模型參數 \(\theta\),交替進行上面兩個步驟來最小化擾動損失。

  • Stochastic Weight Averaging (SWA)

SWA通過在訓練過程中平均模型權重來提高DNN的泛化能力。

  從標準訓練階段開始,SWA開始對後續每個週期的權重進行平均,權重更新爲:

\[\theta_{\mathrm{SWA}}=\frac{1}{T}\sum_{t=1}^{T}\theta_{t} \quad\quad(8) \]

  其中 \(\theta_{t}\) 表示 \(t\) 週期時的模型權重,\(T\) 是應用SWA的週期總數。

Implementation details

  使用以隨機梯度下降(SGD)作爲基礎優化器的SAM進行訓練,動量爲 0.9,初始學習率爲 0.1,權重衰減爲 5e-4,採用餘弦退火學習率策略,數據批次大小爲128。總共訓練 200 個週期,SWA起始週期設置爲 120,將SWA的學習率設置爲 0.05,以增強訓練的有效性和模型魯棒性。設置公式 (3) 中的 \(\beta\) = 10 以進行混合數據增強,所有超參數(包括 \(\lambda_{mix}\)\(\lambda_{crl}\)\(\tau\))均根據驗證集表現上進行調整。

  在對ImageNet預訓練模型DeiT-Base進行微調時,設置學習率爲 0.01,在 50 個週期內權重衰減爲 5e-5,SWA開始週期爲 1,學習率爲 0.004。

Experiments


  表 1 中展示了CIFAR10CIFAR100Tiny-ImageNet上的故障預測結果。

  表 2 展示了在長尾數據集CIFAR10-LTCIFAR100-LT與最先進方法比較。

  表 3 和表 4 展示了在含噪聲標籤的Animal-10NFood-101N上的 top-1 準確率。

  在實際應用中,環境條件容易頻繁變化,例如天氣從晴朗到多雲,再到下雨。對於模型來說,在這種分佈或領域偏移下保持可靠的決策能力至關重要。圖 3 展示了在偏移數據集CIFAR10-C上評估使用CIFAR10的乾淨訓練集訓練的模型的性能比較。

  論文在表 5 中分析了每個組件對SURECIFAR100上的性能貢獻。

  圖 4 中可視化了CIFAR100-LT IF=10上的置信度分佈,SURE明顯比MSPFMFP帶來更好的置信度分離。



如果本文對你有幫助,麻煩點個贊或在看唄~
更多內容請關注 微信公衆號【曉飛的算法工程筆記】

work-life balance.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章