深度學習(增量學習)——GAN在增量學習中的應用(文獻綜述)

前言

我將看過的增量學習論文建了一個github庫,方便各位閱讀地址

持續學習的目的是解決災難性遺忘,當前持續學習(lifelong learning)的研究主要集中在圖像分類這一基礎任務上。圖像分類任務出現災難性遺忘(Catastrophic forgetting),其根源在於TT時刻的分類模型沒有TT時刻之前的圖像數據,意味着需要在沒有輸入分佈的前提下對TT時刻之前的數據進行分類,爲了還原出輸入圖像的分佈,目前有研究開始使用生成對抗模型(Generative Adversarial Nets),原因在於GAN可以進行概率分佈的變化,可以將隱空間中的概率分佈變化爲訓練圖像的概率分佈

​如果僅僅利用TT時刻的數據finetuning GAN,則GAN也會出現災難性遺忘,如下圖所示,將MNIST數據集分爲10個任務,每個任務GAN只學習生成一類數字,利用condition GAN在MNIST數據集上進行持續學習,condition GAN的輸入由類別label、隱空間變量zz組成,可以依據類別label生成對應類別的圖像,訓練完畢後,生成的圖片均爲9,即出現災難性遺忘。

在這裏插入圖片描述

​爲了解決GAN上的災難性遺忘,研究人員採取了一系列措施,大致分爲兩類:

  • 使用記憶重放(Memory replay)機制。
  • Regularization,即在損失函數中添加正則項,來防止災難性遺忘。

Memory replay

​如[6],[1],[3],[5]所示,Memory replay在訓練TT時刻的GAN時,讓T1T-1時刻的GAN生成一批舊類別圖片,與TT時刻的新類別圖片混合在一起,訓練TT時刻的GAN。[1]存儲了部分舊類別的圖片,與生成的舊類別圖片一起訓練TT時刻的GAN。爲了確保每一箇舊類別都具有Memory replay生成的圖片,目前主要採用兩類方式:

  • 使用Uncondition GAN,用T1T-1時刻的分類器對生成的圖片進行分類。爲了保證圖像質量,在用分類器判斷完圖像屬於AA類別後,AA類別得分高於一個閾值θ\theta,纔會用於訓練TT時刻的GAN。
  • 使用Condition GAN,可以依據label生成對應類別的圖片。

Memory replay的缺陷

​ Memory replay的缺點很明顯,若T1T-1時刻的GAN生成的圖片質量極差,無法反映圖像真實的概率分佈,會影響TT時刻GAN的訓練,如下圖所示,Task 3訓練完畢後,生成的iris圖像質量較差,直接導致Task 4、Task 5生成的iris圖像質量較差。
在這裏插入圖片描述

Regularization

​ Regularization即通過在損失函數中添加正則項,來尋求一個合適的解,如下圖紅線所示,通過合適的正則項,可以尋找到即可以較好生成AA任務圖像,又可以生成較好BB任務圖像的解。

在這裏插入圖片描述

​ 目前研究採用的Regularization大致可以分爲兩類:

  • EWC Regularization
  • L2、L1距離

Regularization 方式一:EWC Regularization

​ [2]在Generator的loss中添加了EWC Regularization,計算公式如下圖,第一項爲傳統的Generator loss,第二項爲EWC Regularization。其中FiF_i爲Fisher information
在這裏插入圖片描述
在這裏插入圖片描述

​ EWC Regularization可以從概率角度判斷Generator中,哪一部分參數對於生成舊類別圖像更爲重要,通過限制這類參數發生太大改變,來防止GAN發生災難性遺忘。

Regularization 方式二:L2、L1距離

​ L2、L1距離類似於knowledge distillation,將T1T-1時刻Generator的知識蒸餾到TT時刻的Generator。如下圖所示,[7] [6]將同一隱變量輸入到T1T-1時刻與TT時刻的Generator,將得到的兩張圖片做L1距離或L2距離(對應圖中的LRAL_{RA}),作爲Generator的正則項,若LRAL_{RA}爲L2距離,此時Generator的損失函數變爲式12.0

在這裏插入圖片描述
在這裏插入圖片描述

Regularization存在的問題

​ Regularization要求TT時刻的GAN,對同一輸入,在訓練階段生成的圖片既要與新任務圖像一致(否則無法欺騙Discriminator),又要與舊任務一致(否則Regularization的值會很高),這是矛盾的。

​ 現有的方案通過condition GAN解決上述問題,如下圖所示,假設每次只學習一個新類別,cc表示類別的label,zz表示隱變量,StS_t表示第tt個類別的數據,下圖左半部分表示只有舊類別的隱變量會參與Regularization的計算,右半部分表示只有新類別的數據會用於訓練Discriminator,如此一來,爲了讓Regularization項變小,對於舊類別的隱變量,TT時刻Generator生成的圖片應該與T1T-1時刻Generator生成的圖片儘可能一致,對於新任務,TT時刻的Generator生成的圖片需要儘可能與新任務圖像一致,纔可以欺騙Discriminator。

在這裏插入圖片描述

參考文獻

[1] Amanda Rios ,Laurent Itti. Closed-Loop Memory GAN for Continual Learning. In IJCAI, 2019

[2] Ari Seff, Alex Beatson, Daniel Suo,,Han Liu. Continual Learning in Generative Adversarial Nets.2017

[3] Hanul Shin,Jung Kwon Lee,Jaehong Kim,Jiwon Kim. Continual Learning with Deep Generative Replay. In NIPS, 2017

[4] Yue Wu,Yinpeng Chen,Lijuan Wang,Yuancheng Ye,Zicheng Liu,Yandong Guo,Zhengyou Zhang2,Yun Fu.Incremental Classifier Learning with Generative Adversarial Networks.2018

[5] Ye Xiang,Ying Fu,Pan Ji,Hua Huang.Incremental Learning Using Conditional Adversarial Networks.In ICCV, 2019

[6] henshen Wu,Luis Herranz,Xialei Liu,Yaxing Wang,Joost van de Weijer, Bogdan Raducanu. Memory replay GANs- Learning to generate images from new categories without forgetting. In NIPS, 2018

[7]Mengyao Zhai, Lei Chen,Fred Tung,Jiawei He,Megha Nawhal, Greg Mori.Lifelong GAN: Continual Learning for Conditional Image Generation.In ICCV,2019

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章