上篇文章見:GAN及其變體C_GAN,infoGAN,AC_GAN,DC_GAN(一)
WGAN(Wasserstein GAN)
論文:
前一篇論文並沒有介紹一個算法,而是給出GAN動態訓練的理論理解,一堆公式定理,傷不起,傷不起呀,幸虧有《令人拍案叫絕的Wasserstein GAN》,講解地很詳細,也很通俗易懂,很棒。
在介紹WGAN之前,先介紹兩個數學概念KL散度(Kullback-Leibler divergence)和JS散度(Jensen-Shannon divergence):
KL散度又稱相對熵,信息散度和信息增益,度量兩個概率分佈的匹配程度,兩個分佈差異越大,KL散度就越大,公式定義如下:
有時候將KL散度稱爲KL距離,但是它並不滿足距離的性質,首先KL散度是不對稱的,再次KL散度不滿足三角不等式
JS散度是相似度衡量指標,衡量了兩個概率分佈的相似度,如果,完全相同,那麼JS爲0,如果完全不相同,則取值爲1。假設有兩個分佈和,其JS散度公式爲:
JS基於KL散度的變體,解決了KL散度非對稱的問題,JS是對稱的,並且取值範圍爲[0, 1]。
下面開始進入正題,原始GAN的損失函數爲:
分成兩步進行計算,首先,訓練判別器D最大化
可以得出最簡化的最優判別器爲:
接下來將(2)式代入代入(1)式,就可以得到,所以,原始GAN的優化目標經過一定的數學推導後,可以等價於當判別器最優的時候,最小化真實分佈和生成分佈之間JS散度。然而由於和幾乎不可能有不可忽略的重疊,所以無論它們相距多遠,JS散度都是常數log2,最終導致生成器的梯度近似於0,梯度消失。即使是對接近於最優的判別器來說,生成器有很大機會面臨梯度消失的問題。總結來說,判別器訓練的太好,生成器梯度消失,使得生成器loss降不下去;判別器訓練的不好,生成器梯度不準,四處亂跑。所以只有訓練器訓練的將將好的時候,才能達到要求,然而這個火候很不好把握,即使是同一輪的訓練的前後不同階段,所以GAN纔會不好訓練。
基於等價優化的衡量標準JS散度不合理。WGAN中提出了Wasserstein距離,Wasserstein距離,又稱Earth-Mover(EM)距離,度量兩個概率分佈之間的距離。定義如下:
其中是聯合分佈,對於每一個可能的聯合分佈而言,從中取樣一個真實樣本x和一個生成樣本y, ,計算出這兩個樣本之間的距離,所以可以計算出該聯合分佈下樣本的期望值,在所有可能的聯合分佈中能夠對這個期望值取得下限,就定義爲Wasserstein距離。Wasserstein距離相比KL散度,JS散度的優越性在於,即使兩個分佈沒有重疊,Wasserstein距離仍然能夠反映它們的遠近。論文中給出KL散度和JS散度是突變的,要麼最大或者最小,Wasserstein距離卻是平滑的,如果使用梯度下降法優化這個參數,前兩者是提供不了梯度的,Wasserstein距離卻是可以的。
上述的Wasserstein公式中的下確界沒法直接求解,又經過一系列數學推導,最後變換成如下形式:
由於用到K-Lipschitz函數,公式限制條件爲,在此,作者採取了一個簡單的做法,每次參數更新後,限制神經網絡的所有參數的範圍爲。
至此,就構造了一個含參數,最後一層不是非線性激活函數的判別器網絡,在限制不超過某個範圍的條件下,使得儘可能最大,此時L就會近似真實分佈與生成分佈之間的Wasserstein距離,接下來生成器要最小化Wasserstein距離,可以最小化L。而且由於Wasserstein距離的優良特性,不用擔心生成器梯度消失的問題。如下圖是WGAN的算法過程
由於原始GAN的判別器是一個二分類問題,而WGAN中的判別器是去近似擬合Wasserstein 距離,二分類問題就變成了迴歸任務,所以需要將最後一個的sigmoid函數拿掉。既然Wasserstein 具體可以量化真實分佈和生成分佈之間的距離,可以作爲訓練進程的判別標準,其值越小,表示GAN訓練得越好。
總結來說,Wasserstein GAN(WGAN)主要貢獻在於:
- 徹底解決了GAN訓練不穩定的問題,不再需要小心平衡生成器和判別器的訓練程度
- 基本解決了collapse mode問題,確保了生成樣本的多樣性
- 訓練過程中終於有一個像交叉熵或者準確率這樣的指標指示訓練的進程,Wasserstein 距離,這個數值越小代表GAN訓練得越好,代表生成器產生的圖片質量越高
- 不需要精心設計網絡架構,最簡單的多層全連接網絡就可以做到
improved WGAN
上述介紹的WGAN雖然能夠解決GAN模型訓練時的不穩定性問題,但是參數的修剪策略(weight clipping)會導致最優化困難。由於限制權重的範圍,使得嘗試獲得最大梯度範數的神經網絡架構常常以學得簡單的函數而告終。也就是說,通過權重剪枝實現K-Lipshitz將會趨向更簡單的函數,爲了展示這個結論,使得真實分佈加上unit-variance高斯噪聲作爲生成分佈,作者在幾個toy分佈上訓練WGAN得到最優值,如下圖所示,top是使用weight clipping的結果,bottom是gradient penalty的結果,gradient penalty是爲weight clipping出現的問題提出的改進策略。
如果一個可微函數爲1-Lipschtiz,當且僅當它的gradient with norm <= 1, 那麼我們就可以考慮直接限制input到output的gradient norm,, 因此,在原有critic loss上添加對於來自隨機的樣本gradient norm的懲罰項,新的目標函數如下:
那麼是什麼呢?就是從data distribution中sample一個點,從generator distribution中sample一個點,然後連接這兩個點成一條直線,從這條直線上sample一個點,作爲中的點。論文給出了利用weight clipping和gradient penalty訓練得到的gradient norm(如下圖左邊)和weight分佈(如下圖右邊),可以看到weight clipping學到的weight主要集中在兩個邊界值處,而使用gradient penalty學到的weight的分佈符合我們的設想。
下面是帶gradient penalty的WGAN的算法過程
DualGAN
論文:DualGAN:Unsupervised Dual Learning for Image-to-Image Translation
對偶學習,是出現在機器翻譯領域的一種新的學習範式,對偶學習最關鍵的一點在於給定一個原始任務模型,其對偶任務的模型可以給其提供反饋;同樣的,給定一個對偶任務的模型,其原始任務的模型也可以給該對偶任務的模型提供反饋,那麼這兩個互爲對偶的任務可以相互提供反饋,相互學習,相互提高。對偶學習是微軟亞洲研究院提出來的,見《對偶學習:一種新的機器學習範式》。
下面以一箇中英翻譯遊戲爲例,假設有兩個玩家Alice和小明,Alice講英文,小明講中文,兩人的目地是想要提高中譯英和英譯中模型的準確度。給定一個英文句子x,Alice首先通過英譯中模型f將句子x翻譯成中文,並將傳送給小明,小明雖然不知道Alice具體想要表達的意思,但是小明可以判斷收到的中文句子是不是語法正確,符不符合中文的語言模型,這些信息可以幫助小明大概判斷英譯中模型f是不是做的好,然後小明將這個中文的句子通過中譯英模型g翻譯成一個新的英文句子,併發給Alice,通過比較x和是不是相似,Alice就能夠知道英譯中模型f和中譯英模型g是不是做的好,儘管x是一個沒有標準的句子。可以看到面,這些互爲對偶的任務可以形成一個閉環,使從來沒有標註的數據中進行學習成爲可能。
DualGAN的靈感來源於上述機器翻譯中的對偶問題,如下圖所示,DualGAN中存在兩個生成器和兩個判別器,以素描和照片爲例,生成器對素描像u進行翻譯,其中包含噪聲z,翻譯結果爲,把這個翻譯結果作爲生成器的輸入,附上噪聲,翻譯結果爲,同理對於圖像v。判別器A判別一張圖片是否是photo,而判別器B判別一張圖片是否是sketch。
cycleGAN
論文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
cycleGAN與DualGAN模型相似。都包含兩個生成器和兩個判別器。cycleGAN主要目地是讓兩個domain裏的圖片互相轉化,思想主要是在不配對的訓練樣本中,從一類圖片中捕獲出特定的特徵,然後找出如何將這些特徵轉化成另一類圖片(capturing special characteristics of one image collection and figuring out how these characteristics counld be translated into the other image collection, all in the absence of any [aired training samples)
cycleGAN模型學習兩個domainX和Y之間的映射,對於訓練數據和,定義兩個映射關係,,和兩個判別器和,判別器的作用是區分圖片和轉換的圖片,同樣地判別器的作用是判別圖片和轉換的圖片。目標方程主要包含兩種類型,adversarial losses是爲了使得生成的圖片的分佈和目標域中的數據分佈相等,cycle consistency loss的目地是爲了防止兩個映射G和F互相矛盾。如果我們從domainX轉換到domain Y中然後再從domain Y轉換到X,應當回到原先開始的地方,看起來這是一個循環。下圖(b)中,對來自domain X中的圖像,經過圖像轉化環之後應當將還原,我們稱之爲一個forward cycle-consistency loss: ,下圖(c)中,對來自domainY中的圖像,經過圖像轉化環之後應當將還原一個backward cycle-consistency loss:
對於映射函數和它的判別器,損失函數表示爲:
同理對於映射函數和它的判別器,損失函數表示爲:
cycle consistency loss表示爲
最後,整個目標方程爲