[Paper Note] Learning to Reweight Examples for Robust Deep Learning

Learning to Reweight Examples for Robust Deep Learning

PAPER
CODE

Abstract

面對樣本不平衡問題和標籤噪聲等問題,之前是通過regularizers或者reweight算法,但是需要不斷調整超參取得較好的效果。本文提出了meta-learning的算法,基於梯度方向調整權重。具體做法是需要保證獲得一個足夠乾淨的小樣本數據集,每經過一輪batch大小的訓練就基於當前更新的權重,執行meta gradient descent step來最小化在這個乾淨無偏差的驗證集上的loss。這個方法避免了額外的超參調整,在樣本不平衡和標籤噪聲等問題上可以有很好的效果,所需要的僅僅是一個很小數量的乾淨的驗證集。

Related Work

在解決樣本問題上的工作:

  • 訓練集樣本權重分配:
    AdaBoost:尋找難例來訓練分類器。
    難例挖掘: 下采樣多數樣本,挖掘最難的樣本
    Focal Loss:不同樣本添加不同權重,困難樣本權重更大

  • outliers和noise processes:
    有些方法是先學習簡單樣本在學習困難樣本
    部分工作是去研究如何更好地初始化網絡參數

  • 直接對樣本數據集下手,re-sample之類的

在最近的meta-learning中,很多都在探索使用validation loss作爲meta-objective,本文算法的區別是沒有額外的超參,並且避免了成本較高的離線訓練。

Learning to Reweight Examples

本文的模型看做online approximation而不是meta-learning objective,這樣就可以處理任何常規的監督學習。
文章給出了具體實現並且有理論保證,收斂率爲O(1/ϵ2)O\left(1 / \epsilon^{2}\right)

3.1 From a meta-learning objective to online approximation

(x,y)(x,y)爲輸入-標籤對,{(xi,yi),1iN}\left\{\left(x_{i}, y_{i}\right), 1 \leq i \leq N\right\}爲訓練集,假設{(xiv,yiv),1iM}\left\{\left(x_{i}^{v}, y_{i}^{v}\right), 1 \leq i \leq M\right\}爲一個很小的乾淨無偏差的驗證集,其中MNM \ll N. vv表示驗證集,ii表示第ithi^{th}個數據;同時假設訓練集是包含驗證集的,如果不包含,就把驗證集加入到訓練集中,從而使得訓練過程中能夠利用更多信息。

Φ(x,θ)\Phi(x, \theta)表示網絡模型,θ\theta爲模型參數,定義C(y^,y)C(\hat{y}, y)爲loss函數,其中y^=Φ(x,θ)\hat{y}=\Phi(x, \theta)

在一般的訓練中,我們希望最小化訓練集上的期望loss,也就是1Ni=1NC(y^i,yi)=1Ni=1Nfi(θ)\frac{1}{N} \sum_{i=1}^{N} C\left(\hat{y}_{i}, y_{i}\right)=\frac{1}{N} \sum_{i=1}^{N} f_{i}(\theta),其中每一個輸入樣本權重相等,fi(θ)f_{i}(\theta)表示輸入數據xix_i對應的loss。本文希望通過最小化weighted loss來學習去re-weight 輸入。weighted loss如下:
θ(w)=argminθi=1Nwifi(θ)(1) \theta^{*}(w)=\arg \min _{\theta} \sum_{i=1}^{N} w_{i} f_{i}(\theta) \tag{1}
其中wiw_i一開始未知,{wi}i=1N\left\{w_{i}\right\}_{i=1}^{N}可以被理解成訓練超參數,基於ww在驗證集上的表現來最優化ww
w=argminw,w01Mi=1Mfiv(θ(w))(2) w^{*}=\arg \min _{w, w \geq 0} \frac{1}{M} \sum_{i=1}^{M} f_{i}^{v}\left(\theta^{*}(w)\right) \tag{2}
這裏需要保證wi0w_i \geq 0 對所有的ii,因爲最小化負的training loss可能會導致一些不穩定的情況。

Note:公式(1)(2)實際上就是最小化training loss同時還得保證這時候的權重在驗證集的loss也最小。

Online Approximation 計算最優的wiw_i需要兩層嵌套的最優化循環,並且單個循環成本很高。本文方法的目的是通過一層優化循環來在線調整ww。每一個training iteration中,首先只在training loss平面上檢查部分訓練樣本的下降方向,然後根據和validation loss平面下降方向的相似性對樣本進行reweighting。

大多數深度網絡都有用SGD或者改進版來優化loss。對於training過程的每一步tt,採樣出mini-batch個訓練樣本{(xi,yi),1in}\left\{\left(x_{i}, y_{i}\right), 1 \leq i \leq n\right\}nn是mini-batch size 且nNn \ll N。然後根據mini-batch上期望loss的下降方向來更新參數。
最普通的SGD如下:
θt+1=θtα(1ni=1nfi(θt))(3) \theta_{t+1}=\theta_{t}-\alpha \nabla\left(\frac{1}{n} \sum_{i=1}^{n} f_{i}\left(\theta_{t}\right)\right) \tag{3}
其中α\alpha是step size。

本文想要探究在第tt個training step,什麼因素影響了訓練樣本ii在validation set上的性能。
於是對mini-batch 中的每個樣本,加上一個權重擾動ϵi\epsilon_{i}
fi,ϵ(θ)=ϵifi(θ)(4) f_{i, \epsilon}(\theta)=\epsilon_{i} f_{i}(\theta) \tag{4}
Note: 目前這樣和weighted loss看起來沒什麼區別
θ^t+1(ϵ)=θtαi=1nfi,ϵ(θ)θ=θt(5) \hat{\theta}_{t+1}(\epsilon)=\theta_{t}-\left.\alpha \nabla \sum_{i=1}^{n} f_{i, \epsilon} (\theta)\right|_{\theta=\theta_{t}} \tag{5}
然後尋找在第tt步能夠最小化驗證集lossfvf^{v}的最優的ϵ\epsilon^*
ϵt=argminϵ1Mi=1Mfiv(θt+1(ϵ))(6) \epsilon_{t}^{*}=\arg \min _{\epsilon} \frac{1}{M} \sum_{i=1}^{M} f_{i}^{v}\left(\theta_{t+1}(\epsilon)\right) \tag{6}
這樣計算依然成本很高,爲了更容易的在第tt步估計wiw_i,本文在mini-batch大小的驗證集(ϵt\epsilon_t)上採用一步梯度下降(single gradient descent step),並對輸出進行矯正保證權重均非負。
ui,t=ηϵi,t1mj=1mfjv(θt+1(ϵ))ϵi,t=0(7) u_{i, t}=-\left.\eta \frac{\partial}{\partial \epsilon_{i, t}} \frac{1}{m} \sum_{j=1}^{m} f_{j}^{v}\left(\theta_{t+1}(\epsilon)\right)\right|_{\epsilon_{i, t}=0} \tag{7}
w~i,t=max(ui,t,0)(8) \tilde{w}_{i, t}=\max \left(u_{i, t}, 0\right) \tag{8}
其中η\eta是在ϵ\epsilon上的下降步長。

爲了匹配原本的訓練步長,實際上,我們可以考慮在一個batch中把所有樣本的權重標準化,使他們的和爲1.
也就是對集合加一個強限制,使其滿足{w:w1=1}{0}\{w:\|w\|_{1}=1 \} \cup \{0\}
wi,t=w~i,t(jw~j,t)+δ(jw~j,t)(9) w_{i, t}=\frac{\tilde{w}_{i, t}}{\left(\sum_{j} \tilde{w}_{j, t}\right)+\delta\left(\sum_{j} \tilde{w}_{j, t}\right)} \tag{9}
其中δ()\delta(\cdot)是爲了防止出現在mini-batch中所有的wiw_i都是0的情況,也就是如果a=0a=0那麼有δ(a)=1\delta(a)=1,其他情況下都是δ(a)=0\delta(a)=0。如果沒有批歸一化步驟,該算法可能朝着最高效的學習率的方向修改,本文的one-step方法在學習率的選擇上可能會更保守。此外,通過批歸一化,有效地取消了元學習速率參數η\eta

3.2 Example on MLP

給了一個多層感知器的數學公式推導,感興趣可以閱讀原文

3.3 Implementation using automatic differentiation

在這裏插入圖片描述
Note:圖的流程描述下來就是首先正常前向計算,獲得loss,然後backward,依據樣本權重更新參數,獲得新的參數θ^\hat{\theta},然後放入驗證集forward,backward,根據validation loss 獲得樣本權重並更新。然後繼續步驟1.

詳細算法步驟如下:
在這裏插入圖片描述
使用這個reweight策略的訓練時間是正常訓練網絡的三倍(因爲有兩次forward-backward,還有一個backward-on-backward),如果想縮減時間,可以減少驗證集的batch size。
作者認爲多花時間訓練來避免不斷的調參之類的操作是值得的。

3.4 reweighted training的收斂證明

有空在補,公式太多不想看了

4 Experiments

作者做了兩個實驗,一個是在MNIST上的imbalance問題,一個是CIFAR上的noisy label問題。
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

作者說適用於任何深度網絡,已經有了開源的pytorch代碼,後續加到自己的工作中試一下看看效果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章