【深度學習】生成式對抗網絡原理與GAN/WGAN/WGAN-GP初步

GAN

生成對抗網絡(GAN, Generative Adversarial Networks)相比於CNN與RNN是一項比較新的技術,由Ian J. GoodFellow於2014年提出,隨後不斷髮展,目前已經廣泛用於圖像生成、圖像合成與圖像增強領域。
GAN創新性地將機器學習中的生成模型與判別模型結合了起來,其目的從數學的角度來講是訓練得到的生成器能夠通過輸入隨機的噪聲生成符合指定分佈的數據,例如圖像數據,從而起到圖像合成、生成等一系列目的。
通俗地講,GAN採用了一種博弈的思想,通過設置特殊的代價函數,使生成器生成分類器越來越難以分辨的數據,而分類器經過訓練又逐漸提高其區分生成數據分佈與真實分佈數據的能力。最終兩者不斷促進,分類器的分辨能力不斷提高,生成器生成的數據越來越符合指定分佈,幾乎能夠以假亂真。
這樣巧妙的思路自然也是在數學理論支撐之上的。對於GAN,除了設計兩者合理有效的模型,很關鍵的一點是設計的代價函數,這是因爲代價函數的設計決定了學習的最終目標與學習的效果。
儘管GAN出現於最近10年,其數學理論可以追溯到最大化均值差異(MMD),其用來衡量兩個數據分佈的差異性程度。
理論上,如果兩個分佈一致,其均值相同且對於各種複雜的非線性變換得到的結果均值應該仍然相同。其實這裏說數學期望更加準確,但爲了簡化以及針對有限的數據,往往採用均值來計算。其公式如下:


這樣的公式對SGD並不優化,因爲均值和數學期望一般都是建立在完整的數據下。經過調整,將平方更改爲內積,得到的公式如下:

從公式可以看出,在具體優化時至少需要4個樣本,2個來自p分佈,兩個來自q分佈。另外,這裏變換的選取有很多種,理論上要求Lipschitz係數小於等於1,其實一般來說有界即可,防止因爲較大的係數導致發散。過去流行的變換包括在SVM時代採用的各類核函數或者多核學習,對不同的核函數進行線性組合得到更復雜的變換。
圖1 GAN的基本思想
GAN從一另個角度來看待數據分佈相似性問題,認爲如果分佈相同那麼就無法學到一個分類器將其區分開,那麼就可以先嚐試對噪聲與目標分佈分類,然後固定住分類器更新生成器,然後在嘗試分類,以此類推。
圖2 不斷接近的分佈
這樣變得到了如下函數:

首先看max部分,前面的部分是真實數據的得到的數學期望,後面是“1 – 生成數據”的數學期望,即認爲其爲假的數學期望,這個max的目標則是使分類器的分類能力得到提高;max以後min的目的則是使生成器得到的數據更能以假亂真。從後續代碼實現介紹部分可以看到對其的遵循。
事實上,由於每步無法學到全局最優,這樣的生成網絡會存在模式坍塌的問題,難以學到目標的分佈。尤其是對於傳統的SGD,幾乎是難以避免的發生振盪,收斂難度較大。
GoodFellow後來也嘗試過使用MMD,將神經網絡看做複雜的非線性變換,但因爲缺少對Lipschitz係數限制的考慮存在一定的侷限性。

WGAN

WGAN即Wasserstein GAN,與2017年提出。WGAN主要的改進是對GAN的目標函數,在MMD的基礎上做了更深入的理論分析與研究。

WGAN使用了統計學上的Wasserstein距離。基於問題的單邊特性,Wasserstein距離將MMD的平方與絕對值都去除,去除平方消除了一定的放大作用,從求導的角度對優化也更加有利。另外,爲了使Lipschitz係數小於等於1,WGAN採用clip的方式進行截取,儘管保證了有界,但是與目標的效果是不一樣的,因此會存在一些問題。

WGAN-GP

圖3 WGAN/WGAN-GP效果對比
WGAN-GP在WGAN的基礎上進行了改進,通過在目標函數後加入gradient penalty來實現對Lipschitz係數的限制。Lipschitz係數可以理解爲對x求梯度中最大的梯度,WGAN-GP實際上是加入了對梯度的一種限制,通過加入到目標函數中使之成爲優化的目標。

這裏梯度代價的目標實際上是讓偏導接近於1,這是因爲最優解很有可能在這種情況出現,能夠充分利用約束條件。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章