Learning to Reweight Examples for Robust Deep Learning
Abstract
面對樣本不平衡問題和標籤噪聲等問題,之前是通過regularizers或者reweight算法,但是需要不斷調整超參取得較好的效果。本文提出了meta-learning的算法,基於梯度方向調整權重。具體做法是需要保證獲得一個足夠乾淨的小樣本數據集,每經過一輪batch大小的訓練就基於當前更新的權重,執行meta gradient descent step來最小化在這個乾淨無偏差的驗證集上的loss。這個方法避免了額外的超參調整,在樣本不平衡和標籤噪聲等問題上可以有很好的效果,所需要的僅僅是一個很小數量的乾淨的驗證集。
Related Work
在解決樣本問題上的工作:
-
訓練集樣本權重分配:
AdaBoost:尋找難例來訓練分類器。
難例挖掘: 下采樣多數樣本,挖掘最難的樣本
Focal Loss:不同樣本添加不同權重,困難樣本權重更大 -
outliers和noise processes:
有些方法是先學習簡單樣本在學習困難樣本
部分工作是去研究如何更好地初始化網絡參數 -
直接對樣本數據集下手,re-sample之類的
在最近的meta-learning中,很多都在探索使用validation loss作爲meta-objective,本文算法的區別是沒有額外的超參,並且避免了成本較高的離線訓練。
Learning to Reweight Examples
本文的模型看做online approximation而不是meta-learning objective,這樣就可以處理任何常規的監督學習。
文章給出了具體實現並且有理論保證,收斂率爲
3.1 From a meta-learning objective to online approximation
爲輸入-標籤對,爲訓練集,假設爲一個很小的乾淨無偏差的驗證集,其中. 表示驗證集,表示第個數據;同時假設訓練集是包含驗證集的,如果不包含,就把驗證集加入到訓練集中,從而使得訓練過程中能夠利用更多信息。
用表示網絡模型,爲模型參數,定義爲loss函數,其中。
在一般的訓練中,我們希望最小化訓練集上的期望loss,也就是,其中每一個輸入樣本權重相等,表示輸入數據對應的loss。本文希望通過最小化weighted loss來學習去re-weight 輸入。weighted loss如下:
其中一開始未知,可以被理解成訓練超參數,基於在驗證集上的表現來最優化:
這裏需要保證 對所有的,因爲最小化負的training loss可能會導致一些不穩定的情況。
Note:公式(1)(2)實際上就是最小化training loss同時還得保證這時候的權重在驗證集的loss也最小。
Online Approximation 計算最優的需要兩層嵌套的最優化循環,並且單個循環成本很高。本文方法的目的是通過一層優化循環來在線調整。每一個training iteration中,首先只在training loss平面上檢查部分訓練樣本的下降方向,然後根據和validation loss平面下降方向的相似性對樣本進行reweighting。
大多數深度網絡都有用SGD或者改進版來優化loss。對於training過程的每一步,採樣出mini-batch個訓練樣本,是mini-batch size 且。然後根據mini-batch上期望loss的下降方向來更新參數。
最普通的SGD如下:
其中是step size。
本文想要探究在第個training step,什麼因素影響了訓練樣本在validation set上的性能。
於是對mini-batch 中的每個樣本,加上一個權重擾動
Note: 目前這樣和weighted loss看起來沒什麼區別
然後尋找在第步能夠最小化驗證集loss的最優的
這樣計算依然成本很高,爲了更容易的在第步估計,本文在mini-batch大小的驗證集()上採用一步梯度下降(single gradient descent step),並對輸出進行矯正保證權重均非負。
其中是在上的下降步長。
爲了匹配原本的訓練步長,實際上,我們可以考慮在一個batch中把所有樣本的權重標準化,使他們的和爲1.
也就是對集合加一個強限制,使其滿足
其中是爲了防止出現在mini-batch中所有的都是0的情況,也就是如果那麼有,其他情況下都是。如果沒有批歸一化步驟,該算法可能朝着最高效的學習率的方向修改,本文的one-step方法在學習率的選擇上可能會更保守。此外,通過批歸一化,有效地取消了元學習速率參數。
3.2 Example on MLP
給了一個多層感知器的數學公式推導,感興趣可以閱讀原文
3.3 Implementation using automatic differentiation
Note:圖的流程描述下來就是首先正常前向計算,獲得loss,然後backward,依據樣本權重更新參數,獲得新的參數,然後放入驗證集forward,backward,根據validation loss 獲得樣本權重並更新。然後繼續步驟1.
詳細算法步驟如下:
使用這個reweight策略的訓練時間是正常訓練網絡的三倍(因爲有兩次forward-backward,還有一個backward-on-backward),如果想縮減時間,可以減少驗證集的batch size。
作者認爲多花時間訓練來避免不斷的調參之類的操作是值得的。
3.4 reweighted training的收斂證明
有空在補,公式太多不想看了
4 Experiments
作者做了兩個實驗,一個是在MNIST上的imbalance問題,一個是CIFAR上的noisy label問題。
作者說適用於任何深度網絡,已經有了開源的pytorch代碼,後續加到自己的工作中試一下看看效果。