李宏毅MLDS課程筆記9:Generative Adversarial Network(GAN)

NIPS 2016 Tutorial: Generative Adversarial Networks [Paper] [Video]
tips for training GAN: https://github.com/soumith/ganhacks

Basic Idea of GAN

1、最大化似然函數等價於最小化PdataPG 之間的KL散度。如果PG 取的是GMM的話,產生的image很模糊(GMM與Pdata 差太多,無法模擬Pdata )。GAN的好處之一是,用NN來定義PGPG 更一般化,可以取得更好的效果。
這裏寫圖片描述

2、現在PG 是一個NN,參數爲θ ,從input distribution 中sample得到input low-dim vector z ,經過NN之後,得到generated distribution PG . 根據θ 的不同,可以用簡單的input distribution 產生各種複雜的distribution。

這樣做的問題是難以計算likelihood。
GAN解決了這個問題,在無法計算likelihood的情況下更新θ ,使PG 更像Pdata
解決方法是從天而降一個Discriminator D, 解一個最小最大問題就得到了Generator function(NN) G ,使得input distribution經過Generator function(NN) G 得到的PGPdata 最接近。

這個最小最大問題是

G=argminGmaxDV(G,D)

其中
V(G,D)=ExPdata[logD(x)]+ExPG[log(1D(x))]

這樣定義的好處是,
maxDV(G,D)=V(G,D)=2log2+2JSD(Pdata(x)||PG(x))
衡量了PGPdata 之間的Jensen-Shannon散度(與KL散度不同,JS散度是對稱的)。

現在我們已經把maxDV(G,D) 搞定了,剩下的問題是如何求解

G=argminGmaxDV(G,D)=argminGJSD(Pdata(x)||PG(x))

方法(梯度下降):
1、初始化G0
2、得到D0 , V(G0,D0)=JSD(Pdata(x)||PG0(x))
3、用V(G,D0)θG 的梯度來更新θG ,得到G1 . 有

V(G1,D0)<V(G0,D0)

然而卻不一定有
JSD(Pdata(x)||PG1(x))=V(G1,D1)<V(G0,D0)=JSD(Pdata(x)||PG0(x))

4、得到D1 , V(G1,D1)=JSD(Pdata(x)||PG1(x))
5、用V(G,D1)θG 的梯度來更新θG ,得到G2 .有
V(G2,D1)<V(G1,D1)

然而卻不一定有
JSD(Pdata(x)||PG2(x))=V(G2,D2)<V(G1,D1)=JSD(Pdata(x)||PG1(x))

……

爲了儘量避免出現上面的“不一定有”的情況,對G每次不能更新太多。

3、實際操作中
無法用積分計算V(G,D)中的期望,通過採樣的方法得到V~ 來近似 V 。而V~ 的形式與Binary Classifier的loss function形式相同。這是符合直覺的,因爲我們要得到的Discriminator D就是一個Binary Classifier。
這裏寫圖片描述
(圖中L整體缺少一個負號)

所以實際中的算法是:
這裏寫圖片描述
其中,Pprior(z) 是自己定的簡單分佈。
之所以要把Learning D的過程重複多次是因爲每次得到的不是maxDV(G,D) ,而是maxDV(G,D) 的一個lower bound,重複多次可以使lower bound變大。
之所以把Learning G的過程只進行一次,原因就是上面說過的每次更新G的時候不能更新太多,以免JSD不降反增。

另外,實際中在Learning G的時候,目標函數也有所改變:
這裏寫圖片描述
這是因爲在開始的時候D(x)很小,此時目標函數的微分小,所以訓練慢。
改了目標函數之後,在D(x)很小的時候訓練速度變快,在D(x)接近1(我們的目標)的時候訓練速度慢下來。

Issue about Evaluating the Divergence

實際中遇到一個問題,從discriminator的loss中無法看出生成圖片的質量是否變好,因爲loss總是基本爲0,即discriminator認爲PdataPG 完全沒有overlap。
這有兩個原因。
一是discriminator過於強大,將PdataPG 的採樣點用複雜的邊界區分開。
若是要減弱discriminator的話(update次數少一點、dropout、用比較少的參數),不知道discriminator要調到什麼地步才能得到好的結果。而且,discriminator可以量JSD的前提是,discriminator可以是任何function,因此又希望discriminator能powerful一些。
二是PdataPG 本身就沒有很多overlap。
解決方法是加噪聲:
這裏寫圖片描述

Mode Collapse

Conditional GAN

待續

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章