GIT:斯坦福大學提出應對複雜變換的不變性提升方法 | ICLR 2022

論文對長尾數據集中的複雜變換不變性進行了研究,發現不變性在很大程度上取決於類別的圖片數量,實際上分類器並不能將從大類中學習到的不變性轉移到小類中。爲此,論文提出了GIT生成模型,從數據集中學習到類無關的複雜變換,從而在訓練時對小類進行有效增強,整體效果不錯

來源:曉飛的算法工程筆記 公衆號

論文: Do Deep Networks Transfer Invariances Across Classes?

Introduction


  優秀的泛化能力需要模型具備忽略不相關細節的能力,比如分類器應該對圖像的目標是貓還是狗進行響應,而不是背景或光照條件。換句話說,泛化能力需要包含對複雜但不影響預測結果的變換的不變性。在給定足夠多的不同圖片的情況下,比如訓練數據集包含在大量不同背景下的貓和狗的圖像,深度神經網絡的確可以學習到不變性。但如果狗類的所有訓練圖片都是草地背景,那分類器很可能會誤判房子背景中的狗爲貓,這種情況往往就是不平衡數據集存在的問題。
  類不平衡在實踐中很常見,許多現實世界的數據集遵循長尾分佈,除幾個頭部類有很多圖片外,而其餘的每個尾部類都有很少的圖片。因此,即使長尾數據集中圖片總量很大,分類器也可能難以學習尾部類的不變性。雖然常用的數據增強可以通過增加尾部類中的圖片數量和多樣性來解決這個問題,但這種策略並不能用於模仿複雜變換,如更換圖片背景。需要注意的是,像照明變化之類的許多複雜變換是類別無關的,能夠類似地應用於任何類別的圖片。理想情況下,經過訓練的模型應該能夠自動將這些不變性轉爲類無關的不變性,兼容尾部類的預測。
  論文通過實驗觀察分類器跨類遷移學習到的不變性的能力,從結果中發現即使經過過採樣等平衡策略後,神經網絡在不同類別之間傳遞學習到的不變性也很差。例如,在一個長尾數據集上,每個圖片都是隨機均勻旋轉的,分類器往往對來自頭部類的圖片保持旋轉不變,而對來自尾部類的圖片則不保持旋轉不變。
  爲此,論文提出了一種更有效地跨類傳遞不變性的簡單方法。首先訓練一個input conditioned但與類無關的生成模型,該模型用於捕獲數據集的複雜變換,隱藏了類信息以便鼓勵類之間的變換轉移。然後使用這個生成模型來轉換訓練輸入,類似於學習數據增強來訓練分類器。論文通過實驗證明,由於尾部類的不變性得到顯著提升,整體分類器對複雜變換更具不變性,從而有更好的測試準確率。

Measuring Invariance Transfer In Class-Imbalanced Datasets


  論文先對不平衡場景中的不變性進行介紹,隨後定義一個用於度量不變性的指標,最後再分析不變性與類別大小之間的關係。

Setup:Classification,Imbalance,and Invariances

  定義輸入\((x,y)\),標籤\(y\)屬於\(\{1,\cdots,C\}\)\(C\)爲類別數。定義訓練後的模型的權值\(w\),用於預測條件概率\(\tilde{P}_w(y=j|x)\),分類器將選擇概率最大的類別\(j\)作爲輸出。給定訓練集\(\{(x^{(i)}, y^{(i)})\}^N_{i=1}\sim \mathbb{P}_{train}\),通過經驗風險最小化(ERM)來最小化訓練樣本的平均損失。但在不平衡場景下,由於\(\{y^{(i)}\}\)的分佈不是均勻的,導致ERM在少數類別上表現不佳。
  在現實場景中,最理想的是模型在所有類別上都表現得不錯。爲此,論文采用類別平衡的指標來評價分類器,相當於測試分佈\(\mathbb{P}_{test}\)\(y\)上是均勻的。
  爲了分析不變性,論文假設\(x\)的複雜變換分佈爲\(T(\cdot|x)\)。對於不影響標籤的複雜變換,論文希望分類器是不變的,即預測的概率不會改變:

Measuring Learned Invariacnes

  爲了度量分類器學習不變性的程度,論文定義了原輸入和變換輸入之間的期望KL散度(eKLD):

  這是一個非負數,eKLD越低代表不變性程度就越高,對\(T\)完全不變的分類器的eKLD爲0。如果有辦法採樣\(x^{'}\sim T(\cdot|x)\),就能計算訓練後的分類器的eKLD。此外,爲了研究不變性與類圖片數量的關係,可以通過分別計算類特定的eKLD進行分析,即將公式2的\(x\)限定爲類別\(j\)所屬。
  計算eKLD的難點在於複雜變化分佈\(T\)的獲取。對於大多數現實世界的數據集而言,其複雜變化分佈是不可知的。爲此,論文通過選定複雜分佈來生成數據集,如RotMNIST數據集。與數據增強不同,這種生成方式是通過變換對數據集進行擴充,而不是在訓練過程對同一圖片應用多個隨機採樣的變換。
  論文以Kuzushiji-49作爲基礎,用三種不同的複雜變換生成了三個不同的數據集:圖片旋轉(K49-ROT-LT)、不同背景強度(K49-BG-LT)和圖像膨脹或侵蝕(K49-DIL-LT)。爲了使數據集具有長尾分佈(LT),先從大到小隨機選擇類別,然後有選擇地減少類別的圖片數直到數量分佈符合參數爲2.0的Zipf定律,同時強制最少的類爲5張圖片。重複以上操作30次,構造30個不同的長尾數據集。每個長尾數據集有7864張圖片,最多的類有4828張圖片,最小的類有5張圖片,而測試集則保持原先的不變。

  訓練方面,採用標準ERM和CE+DRS兩種方法,其中CE+DRS基於交叉熵損失進行延遲的類平衡重採樣。DRS在開始階段跟ERM一樣隨機採樣,隨後再切換爲類平衡採樣進行訓練。論文爲每個訓練集進行兩種分類器的訓練,隨後計算每個分類器每個類別的eKLD指標。結果如圖1所示,可以看到兩個現象:

  • 在不同變化數據集上,不變性隨着類圖片數減少都降低了。這表明雖然複雜變換是類無關的,但在不平衡數據集上,模型無法在類之間傳遞學習到的不變性。
  • 對於圖片數量相同的類,使用CE+DRS訓練的分類器往往會有較低的eKLD,即更好的不變性。但從曲線上看,DRS仍有較大的提升空間,還沒達到類別之間一致的不變性。

Trasnferring Invariances with Generative Models


  從前面的分析可以看到,長尾數據集的尾部類對複雜變換的不變性較差。下面將介紹如何通過生成式不變性變換(GIT)來顯式學習數據集中的複雜變換分佈\(T(\cdot|x)\),進而在類間轉移不變性。

Learning Nuisance Transformations from Data

  如果有數據集實際相關的複雜變換的方法,可以直接將其用作數據增強來加強所有類的不變性,但在實踐中很少出現這種情況。於是論文提出GIT,通過訓練input conditioned的生成模型\(\tilde{T}(\cdot|x)\)來近似真實的複雜變換分佈\(T(\cdot|x)\)

  論文參考了多模態圖像轉換模型MUNIT來構造生成模型,該類模型能夠從數據中學習到多種複雜變換,然後對輸入進行變換生成不同的輸出。論文對MUNIT進行了少量修改,使其能夠學習單數據集圖片之間的變換,而不是兩個不同域數據集之間的變換。從圖2的生成結果來看,生成模型能夠很好地捕捉數據集中的複雜變換,即使是尾部類也有不錯的效果。需要注意的是,MUNIT是非必須的,也可以嘗試其它可能更好的方法。
  在訓練好生成模型後,使用GIT作爲真實複雜變換的代理來爲分類器進行數據增強,希望能夠提高尾部類對複雜變換的不變性。給定訓練輸入\(\{(x^{(i)}, y^{(i)})\}^{|B|}_{i=1}\),變換輸入\(\tilde{x}^{(i)}\gets \tilde{T}(\cdot|x^{(i)})\),保持標籤不變。這樣的變換能夠提高分類器在訓練期間的輸入多樣性,特別是對於尾部類。需要注意的是,batch可以搭配任意的採樣方法(Batch Sampler),比如類平衡採樣器。此外,還可以有選擇地進行增強,避免由於生成模型的缺陷損害性能的可能性,比如對數量足夠且不變性已經很好的頭部類不進行增強。

  在訓練中,論文設置閾值\(K\),僅圖片數量少於\(K\)的類進行數據增強。此外,僅對每個batch的\(p\)比例進行增強。\(p\)一般取0.5,而\(K\)根據數據集可以設爲20-500,整體邏輯如算法1所示。

GIT Improves Invariance on Smaller Classes

  論文基於算法1進行了實驗,將Batch Sampler設爲延遲重採樣(DRS),Update Classifier使用交叉熵梯度更新,整體模型標記爲\(CE+DRS+GIT(all classes)\)。all classes表示禁用閾值\(K\),僅對K49數據集使用。作爲對比,Oracle則是用於構造生成數據集的真實變換。從圖3的對比結果可以看到,GIT能夠有效地增強尾部類的不變性,但同時也損害了圖片充裕的頭部類的不變性,這表明了閾值\(K\)的必要性。

Experiment


  不同訓練策略搭配GIT的效果對比。

  在GTSRB和CIFAR數據集上的變換輸出。

  CIFAR-10上每個類的準確率。

  對比實驗,包括閾值\(K\)對性能的影響,GTSRB-LT, CIFAR-10 LT和CIFAR-100 LT分別取25、500和100。這裏的最好性能貌似都比RandAugment差點,有可能是因爲論文還沒對實驗進行調參,而是直接複用了RandAugment的實驗參數。這裏比較好奇的是,如果在訓練生成模型的時候加上RandAugment,說不定性能會更好。

Conclusion


  論文對長尾數據集中的複雜變換不變性進行了研究,發現不變性在很大程度上取決於類別的圖片數量,實際上分類器並不能將從大類中學習到的不變性轉移到小類中。爲此,論文提出了GIT生成模型,從數據集中學習到類無關的複雜變換,從而在訓練時對小類進行有效增強,整體效果不錯。



如果本文對你有幫助,複雜點個贊或在看唄~
更多內容請關注 微信公衆號【曉飛的算法工程筆記】

work-life balance.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章