TF2 GAN篇之GAN原理及推導【附WGAN代碼和數據集】

TF2 GAN篇之GAN原理及推導

相關文章導航



GAN的起源

在生成對抗網絡(Generative Adversarial Network,簡稱GAN)發明之前,變分自編碼器被認爲是理論完備,實現簡單,使用神經網絡訓練起來很穩定,生成的圖片逼近度也較高,但是人眼還是可以很輕易地分辨出真實圖片與機器生成的圖片。

2014 年,Université de Montréal 大學Yoshua Bengio(2019 年圖靈獎獲得者)的學生IanGoodfellow 提出了生成對抗網絡GAN [1],從而開闢了深度學習最炙手可熱的研究方向之一。從2014 年到2019 年,GAN 的研究穩步推進,研究捷報頻傳,最新的GAN 算法在圖片生成上的效果甚至達到了肉眼難辨的程度,着實令人振奮。由於GAN 的發明,Ian
Goodfellow 榮獲GAN 之父稱號,並獲得2017 年麻省理工科技評論頒發的35 Innovators
Under 35 獎項。如圖展示了從2014 年到2018 年,GAN 模型取得了非凡的效果,可以看到不管是圖片大小,還是圖片逼真度,都有了巨大的提升。

在這裏插入圖片描述

接下來,我們將從生活中博弈學習的實例出發,一步步引出GAN 算法的設計思想和模型結構。

博弈學習實例

我們用一個漫畫家的成長軌跡來形象介紹生成對抗網絡的思想。考慮一對雙胞胎兄弟,分別稱爲老二G 和老大D,G 學習如何繪製漫畫,D 學習如何鑑賞畫作。還在娃娃時代的兩兄弟,尚且只學會瞭如何使用畫筆和紙張,G 繪製了一張不明所以的畫作,如圖(a)所示,由於此時D 鑑別能力不高,覺得G 的作品還行,但是人物主體不夠鮮明。在D 的指引和鼓勵下,G 開始嘗試學習如何繪製主體輪廓和使用簡單的色彩搭配。
注: 字母D其實就算代表discriminator 字母G其實就算代表Generator

一年後,G 提升了繪畫的基本功,D 也通過分析名作和初學者G 的作品,初步掌握了鑑別作品的能力。此時D 覺得G 的作品人物主體有了,如第二張圖(b),但是色彩的運用還不夠成熟。數年後,G 的繪畫基本功已經很紮實了,可以輕鬆繪製出主體鮮明、顏色搭配合適和逼真度較高的畫作,如第三張圖©,但是D 同樣通過觀察G 和其它名作的差別,提升了畫作鑑別能力,覺得G 的畫作技藝已經趨於成熟,但是對生活的觀察尚且不夠,作品沒有傳達神情且部分細節不夠完美。又過了數年,G 的繪畫功力達到了爐火純青的地步,繪製的作品細節完美、風格迥異、惟妙惟肖,宛如大師級水準,如第4張圖(d),即便此時的D 鑑別功力也相當出色,亦很難將G 和其他大師級的作品區分開來。

在這裏插入圖片描述

在原始的GAN 論文中,Ian Goodfellow 使用了另一個形象的比喻來介紹GAN 模型:生成器網絡G 的功能就是產生一系列非常逼真的假鈔試圖欺騙鑑別器D,而鑑別器D 通過學習真鈔和生成器G 生成的假鈔來掌握鈔票的鑑別方法。這兩個網絡在相互博弈的過程中間同步提升,直到生成器G 產生的假鈔非常的逼真,連鑑別器D 都真假難辨。

這種博弈學習的思想使得GAN 的網絡結構和訓練過程與之前的網絡模型略有不同,下面我們來詳細介紹GAN 的網絡結構和算法原理。

GAN 原理

1. 網絡結構

生成對抗網絡包含了兩個子網絡:生成網絡(Generator,簡稱G)和判別網(Discriminator,簡稱D),其中生成網絡G 負責學習樣本的真實分佈,判別網絡D 負責將生成網絡採樣的樣本與真實樣本區分開來。

生成網絡G(𝒛) 生成網絡G 和自編碼器的Decoder 功能類似,從先驗分佈𝑝𝒛(∙)中採樣隱藏變量𝒛~𝑝𝒛(∙),通過生成網絡G 參數化的𝑝𝑔(𝒙|𝒛)分佈,獲得生成樣本𝒙~𝑝𝑔(𝒙|𝒛), 如圖所示。其中隱藏變量𝒛的先驗分佈𝑝𝒛(∙)可以假設爲某中已知的分佈,比如多元均勻分佈𝑧~Uniform(−1,1)。
在這裏插入圖片描述
𝑝𝑔(𝒙|𝒛)可以用深度神經網絡來參數化,如下圖 13.4 所示,從均勻分佈𝑝𝒛(∙)中採樣出隱藏變量𝒛,經過多層轉置卷積層網絡參數化的𝑝𝑔(𝒙|𝒛)分佈中採樣出樣本𝒙𝑓。從輸入輸出層面來看,生成器G 的功能是將隱向量𝒛通過神經網絡轉換爲樣本向量𝒙𝑓,下標𝑓代表假樣本(Fake samples)。

在這裏插入圖片描述
判別網絡D(𝒙) 判別網絡和普通的二分類網絡功能類似,它接受輸入樣本𝒙的數據集,包含了採樣自真實數據分佈𝑝𝑟(∙)的樣本𝒙𝑟𝑝𝑟(∙),也包含了採樣自生成網絡的假樣本𝒙𝑓𝑝𝑔(𝒙|𝒛),𝒙𝑟和𝒙𝑓共同組成了判別網絡的訓練數據集。判別網絡輸出爲𝒙屬於真實樣本的概率𝑃(𝒙爲真|𝒙),我們把所有真實樣本𝒙𝑟的標籤標註爲真(1),所有生成網絡產生的樣本𝒙𝑓標註爲假(0),通過最小化判別網絡D 的預測值與標籤之間的誤差來優化判別網絡參數,如圖所示

在這裏插入圖片描述

2. 網絡訓練

GAN 博弈學習的思想體現在在它的訓練方式上,由於生成器G 和判別器D 的優化目標不一樣,不能和之前的網絡模型的訓練一樣,只採用一個損失函數。下面我們來分別介紹如何訓練生成器G 和判別器D。

對於判別網絡D,它的目標是能夠很好地分辨出真樣本𝒙𝑟與假樣本𝒙𝑓。以圖片生成爲例,它的目標是最小化圖片的預測值和真實值之間的交叉熵損失函數
在這裏插入圖片描述
其中𝐷𝜃(𝒙𝑟)代表真實樣本𝒙𝑟在判別網絡𝐷𝜃的輸出,𝜃爲判別網絡的參數集,𝐷𝜃(𝒙𝑓)爲生成樣本𝒙𝑓在判別網絡的輸出,𝑦𝑟爲𝒙𝑟的標籤,由於真實樣本標註爲真,故𝑦𝑟 = 1,𝑦𝑓爲生成樣本的𝒙𝑓的標籤,由於生成樣本標註爲假,故𝑦𝑓 = 0。CE 函數代表交叉熵損失函數CrossEntropy。二分類問題的交叉熵損失函數定義爲:
在這裏插入圖片描述

以下的推導過程用手寫的形式進行
在這裏插入圖片描述

統一目標函數

我們把判別網絡的目標和生成網絡的目標合併,寫成min − max博弈形式:

說起來比較高大上,其實就算加起來嘿嘿

在這裏插入圖片描述在這裏插入圖片描述

GAN變種

在原始的GAN 論文中,Ian Goodfellow 從理論層面分析了GAN 網絡的收斂性,並且在多個經典圖片數據集上測試了圖片生成的效果,如圖 13.9 所示,其中圖(a)爲MNIST 數據,圖(b)爲Toronto Face 數據集,圖 ( c )、圖(d)爲CIFAR10 數據集。

在這裏插入圖片描述
可以看到,原始GAN 模型在圖片生成效果上並不突出,和VAE 差別不明顯,此時並沒有展現出它強大的分佈逼近能力。但是由於GAN 在理論方面較新穎,實現方面也有很多可以改進的地方,大大地激發了學術界的研究興趣。在接下來的數年裏,GAN 的研究如火如荼的進行,並且也取得了實質性的進展。接下來我們將介紹幾個意義比較重大的GAN變種。

DCGAN

最初始的GAN 網絡主要基於全連接層實現生成器G 和判別器D 網絡,由於圖片的維度較高,網絡參數量巨大,訓練的效果並不優秀。DCGAN [2]提出了使用轉置卷積層實現的生成網絡,普通卷積層來實現的判別網絡,大大地降低了網絡參數量,同時圖片的生成效果也大幅提升,展現了GAN 模型在圖片生成效果上超越VAE 模型的潛質。此外,DCGAN 作者還提出了一系列經驗性的GAN 網絡訓練技巧,這些技巧在WGAN 提出之前被證實有益於網絡的穩定訓練。

InfoGAN

InfoGAN [3]嘗試使用無監督的方式去學習輸入𝒙的可解釋隱向量𝒛的表示方法(Interpretable Representation),即希望隱向量𝒛能夠對應到數據的語義特徵。比如對於MNIST 手寫數字圖片,我們可以認爲數字的類別、字體大小和書寫風格等是圖片的隱藏變量,希望模型能夠學習到這些分離的(Disentangled)可解釋特徵表示方法,從而可以通過人爲控制隱變量來生成指定內容的樣本。對於CelebA 名人照片數據集,希望模型可以把髮型、眼鏡佩戴情況、面部表情等特徵分隔開,從而生成指定形態的人臉圖片。分離的可解釋特徵有什麼好處呢?它可以讓神經網絡的可解釋性更強,比如𝒛包含了一些分離的可解釋特徵,那麼我們可以通過僅僅改變這一個位置上面的特徵來獲得不同語義的生成數據,如圖所示,通過將“戴眼鏡男士”與“不戴眼鏡男士”的隱向量相減,並與“不戴眼鏡女士”的隱向量相加,可以生成“戴眼鏡女士”的生成圖片。

在這裏插入圖片描述

CycleGAN

CycleGAN 是華人朱儁彥提出的無監督方式進行圖片風格相互轉換的算法,由於算法清晰簡單,實驗效果完成的較好,這項工作受到了很多的讚譽。CycleGAN 基本的假設是,如果由圖片A 轉換到圖片B,再從圖片B 轉換到A′,那麼A′應該和A 是同一張圖片。因此除了設立標準的GAN 損失項外,CycleGAN 還增設了循環一致性損失(CycleConsistency Loss),來保證A′儘可能與A 逼近。CycleGAN 圖片的轉換效果如圖所示。

在這裏插入圖片描述

WGAN

GAN 的訓練問題一直被詬病,很容易出現訓練不收斂和模式崩塌的現象。WGAN 從理論層面分析了原始的GAN 使用JS 散度存在的缺陷,並提出了可以使用Wasserstein 距離來解決這個問題。在WGAN-GP [6]中,作者提出了通過添加梯度懲罰項,從工程層面很好的實現了WGAN 算法,並且實驗性證實了WGAN 訓練穩定的優點。

Equal GAN

從 GAN 的誕生至2017 年底,GAN Zoo 已經收集超過了214 種GAN 網絡變種③。這些GAN 的變種或多或少地提出了一些創新,然而Google Brain 的幾位研究員在論文中提供了另一個觀點:沒有證據表明我們測試的GAN 變種算法一直持續地比最初始的GAN要好。論文中對這些GAN 變種進行了相對公平、全面的比較,在有足夠計算資源的情況下,發現幾乎所有的GAN 變種都能達到相似的性能(FID 分數)。這項工作提醒業界是否這些GAN 變種具有本質上的創新

在 SAGAN 的基礎上,BigGAN [9]嘗試將GAN 的訓練擴展到大規模上去,利用正交
正則化等技巧保證訓練過程的穩定性。BigGAN 的意義在於啓發人們,GAN 網絡的訓練同樣可以從大數據、大算力等方面受益。BigGAN 圖片生成效果達到了前所未有的高度:
Inception score 記錄提升到166.5(提高了52.52);Frechet Inception Distance 下降到7.4,降低了18.65,如圖所示,圖片的分辨率可達512 × 512,圖片細節極其逼真。
在這裏插入圖片描述

納什均衡

現在我們從理論層面進行分析,通過博弈學習的訓練方式,生成器G 和判別器D 分別會達到什麼平衡狀態。具體地,我們將探索以下兩個問題:

❑ 固定 G,D 會收斂到什麼最優狀態𝐷∗?
❑ 在 D 達到最優狀態𝐷∗後,G 會收斂到什麼狀態?

首先我們通過𝒙𝑟~𝑝𝑟(∙)一維正態分佈的例子給出一個直觀的解釋。如圖所示,黑色虛線曲線代表了真實數據的分佈𝑝𝑟(∙),爲某正態分佈𝒩(𝜇, 𝜎2),綠色實線代表了生成網絡學習到的分佈𝒙𝑓~𝑝𝑔(∙),藍色虛線代表了判別器的決策邊界曲線,圖(a)、(b)、©、(d)分別代表了生成網絡的學習軌跡。在初始狀態,如圖(a)所示,𝑝𝑔(∙)分佈與𝑝𝑟(∙)差異較大,判別器可以很輕鬆地學習到明確的決策邊界,即圖(a)中的藍色虛線,將來自𝑝𝑔(∙)的採樣點判定爲0,𝑝𝑟(∙)中的採樣點判定爲1。隨着生成網絡的分佈𝑝𝑔(∙)越來越逼近真實分佈𝑝𝑟(∙),判別器越來越困難將真假樣本區分開,如圖(b)©所示。最後,生成網絡學習到的分佈𝑝𝑔(∙) =𝑝𝑟(∙)時,此時從生成網絡中採樣的樣本非常逼真,判別器無法區分,即判定爲真假樣本的概率均等,如圖(d)所示
在這裏插入圖片描述

判別器狀態

現在來推導第一個問題。回顧GAN 的損失函數:
在這裏插入圖片描述
對於判別器D,優化的目標是最大化ℒ(𝐺, 𝐷)函數,需要找出函數:

在這裏插入圖片描述

的最大值,其中𝜃爲判別器𝐷的網絡參數。

我們來考慮𝑓𝜃更通用的函數的最大值情況:
在這裏插入圖片描述
要求得函數𝑓(𝑥)的最大值。考慮𝑓(𝑥)的導數:

在這裏插入圖片描述在這裏插入圖片描述
當對x的偏導爲0時,可以求得𝑓(𝑥)函數的極值點:

在這裏插入圖片描述
因此,可以得知,𝑓𝜃函數的極值點同樣爲:
在這裏插入圖片描述
也就是說,判別器網絡𝐷𝜃處於𝐷𝜃∗狀態時,𝑓𝜃函數取得最大值,ℒ(𝐺, 𝐷)函數也取得最大
值。
現在回到最大化ℒ(𝐺, 𝐷)的問題,ℒ(𝐺, 𝐷)的最大值點在:在這裏插入圖片描述
時取得,此時也是𝐷𝜃的最優狀態𝐷∗

生成器狀態

在推導第二個問題之前,我們先介紹一下與KL 散度類似的另一個分佈距離度量標準:JS 散度,它定義爲KL 散度的組合:
在這裏插入圖片描述
JS 散度克服了KL 散度不對稱的缺陷。
當 D 達到最優狀態𝐷∗時,我們來考慮此時𝑝r和𝑝g的JS 散度:
在這裏插入圖片描述
根據 KL 散度的定義展開爲:
在這裏插入圖片描述
合併常數項可得:
在這裏插入圖片描述在這裏插入圖片描述
考慮在判別網絡到達𝐷∗時,此時的損失函數爲:
在這裏插入圖片描述
因此在判別網絡到達𝐷∗時,𝐷𝐽𝑆(𝑝𝑟||𝑝𝑔)與ℒ(𝐺, 𝐷∗)滿足關係:
在這裏插入圖片描述

在這裏插入圖片描述
此時生成網絡𝐺∗的狀態是 : 𝑝𝑔 = 𝑝𝑟
即𝐺∗的學到的分佈𝑝𝑔與真實分佈𝑝𝑟一致,網絡達到平衡點,此時:
在這裏插入圖片描述

納什均衡點

通過上面的推導,我們可以總結出生成網絡G 最終將收斂到真實分佈,即:𝑝𝑔 = 𝑝𝑟

此時生成的樣本與真實樣本來自同一分佈,真假難辨,在判別器中均有相同的概率判定爲
真或假,即

𝐷(∙) = 0.5
此時損失函數爲:
在這裏插入圖片描述

GAN 訓練難題

儘管從理論層面分析了GAN 網絡能夠學習到數據的真實分佈,但是在工程實現中,常常出現GAN 網絡訓練困難的問題,主要體現在GAN 模型對超參數較爲敏感,需要精心挑選能使模型工作的超參數設定,同時也容易出現模式崩塌現象。

1. 超參數敏感

超參數敏感是指網絡的結構設定、學習率、初始化狀態等超參數對網絡的訓練過程影響較大,微量的超參數調整將可能導致網絡的訓練結果截然不同。如圖 13.15 所示,圖(a)爲GAN 模型良好訓練得到的生成樣本,圖(b)中的網絡由於沒有采用Batch Normalization層等設置,導致GAN 網絡訓練不穩定,無法收斂,生成的樣本與真實樣本差距非常大。

在這裏插入圖片描述

2. 模式崩塌

模式崩塌(Mode Collapse)是指模型生成的樣本單一,多樣性很差的現象。由於判別器只能鑑別單個樣本是否採樣自真實分佈,並沒有對樣本多樣性進行顯式約束,導致生成模型可能傾向於生成真實分佈的部分區間中的少量高質量樣本,以此來在判別器中獲得較高的概率值,而不會學習到全部的真實分佈。模式崩塌現象在GAN 中比較常見,如圖
所示,在訓練過程中,通過可視化生成網絡的樣本可以觀察到,生成的圖片種類非常單一,生成網絡總是傾向於生成某種單一風格的樣本圖片,以此騙過判別器。

在這裏插入圖片描述
另一個直觀地理解模式崩塌的例子如圖 13.17 所示,第一行爲未出現模式崩塌現象的生成網絡的訓練過程,最後一列爲真實分佈,即2D 高斯混合模型;第二行爲出現模式崩塌現象的生成網絡的訓練過程,最後一列爲真實分佈。可以看到真實的分佈由8 個高斯模型混合而成,出現模式崩塌後,生成網絡總是傾向於逼近真實分佈的某個狹窄區間,如圖第2 行前6 列所示,從此區間採樣的樣本往往能夠在判別器中較大概率判斷爲真實樣本,從而騙過判別器。但是這種現象並不是我們希望看到的,我們希望生成網絡能夠逼近真實的分佈,而不是真實分佈中的某部分。
在這裏插入圖片描述
那麼怎麼解決GAN 訓練的難題,讓GAN 可以像普通的神經網絡一樣訓練較爲穩定
呢?WGAN 模型給出了一種解決方案。

WGAN原理

WGAN 算法從理論層面分析了GAN 訓練不穩定的原因,並提出了有效的解決方法。那麼是什麼原因導致了GAN 訓練如此不穩定呢?WGAN 提出是因爲JS 散度在不重疊的分佈𝑝和𝑞上的梯度曲面是恆定爲0 的。如圖 13.19 所示,當分佈𝑝和𝑞不重疊時,JS 散度的梯度值始終爲0,從而導致此時GAN 的訓練出現梯度彌散現象,參數長時間得不到更新,網絡無法收斂。

接下來我們將詳細闡述JS 散度的缺陷以及怎麼解決此缺陷。

1.JS散度的缺陷

爲了避免過多的理論推導,我們這裏通過一個簡單的分佈實例來解釋JS 散度的缺陷。考慮完全不重疊(𝜃 ≠ 0)的兩個分佈𝑝和𝑞,其中分佈𝑝爲:

∀(𝑥, 𝑦) ∈ p, 𝑥 = 𝜃, 𝑦 ∼ U(0,1)

分佈𝑞爲:

∀(𝑥, 𝑦) ∈ 𝑞, 𝑥 = 𝜃, 𝑦 ∼ U(0,1)

其中𝜃 ∈ 𝑅,當𝜃 = 0時,分佈𝑝和𝑞重疊,兩者相等;當𝜃 ≠ 0時,分佈𝑝和𝑞不重疊。
在這裏插入圖片描述
我們來分析上述分佈𝑝和𝑞之間的JS 散度隨𝜃的變化情況。根據KL 散度與JS 散度的定義,計算𝜃 = 0時的JS 散度𝐷𝐽𝑆(𝑝||𝑞):

在這裏插入圖片描述
當𝜃 = 0時,兩個分佈完全重疊,此時的JS 散度和KL 散度都取得最小值,即0:
𝐷𝐾𝐿(𝑝||𝑞) = 𝐷𝐾𝐿 (𝑞||𝑝) = 𝐷𝐽𝑆(𝑝||𝑞) = 0
從上面的推導,我們可以得到𝐷𝐽𝑆(𝑝||𝑞)隨𝜃的變化趨勢:
在這裏插入圖片描述
也就是說,當兩個分佈完全不重疊時,無論分佈之間的距離遠近,JS 散度爲恆定值log2,此時JS 散度將無法產生有效的梯度信息;當兩個分佈出現重疊時,JS 散度纔會平滑變動,產生有效梯度信息;當完全重合後,JS 散度取得最小值0。如圖 13.19 中所示,紅色的曲線分割兩個正態分佈,由於兩個分佈沒有重疊,生成樣本位置處的梯度值始終爲0,無法更新生成網絡的參數,從而出現網絡訓練困難的現象。
在這裏插入圖片描述
因此,JS 散度在分佈𝑝和𝑞不重疊時是無法平滑地衡量分佈之間的距離,從而導致此位置上無法產生有效梯度信息,出現GAN 訓練不穩定的情況。要解決此問題,需要使用一種更好的分佈距離衡量標準,使得它即使在分佈𝑝和𝑞不重疊時,也能平滑反映分佈之間的真實距離變化

2. EM距離

WGAN 論文發現了JS 散度導致GAN 訓練不穩定的問題,並引入了一種新的分佈距離度量方法:Wasserstein 距離,也叫推土機距離(Earth-Mover Distance,簡稱EM 距離),它表示了從一個分佈變換到另一個分佈的最小代價,定義爲:
在這裏插入圖片描述
其中Π(𝑝, 𝑞**)是分佈𝑝和𝑞組合起來的所有可能的聯合分佈的集合**,對於每個可能的聯合分佈𝛾 ∼ Π(𝑝, 𝑞),計算距離‖𝑥 − 𝑦‖的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],其中(𝑥, 𝑦)採樣自聯合分佈𝛾。不同的聯合分佈𝛾有不同的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],這些期望中的下確界即定義爲分佈𝑝和𝑞的Wasserstein 距離。其中inf{∙}表示集合的下確界,例如{𝑥|1 < 𝑥 < 3, 𝑥 ∈ 𝑅}的下確界爲1。

繼續考慮上圖中的例子,我們直接給出分佈𝑝和𝑞之間的EM 距離的表達式:

𝑊(𝑝, 𝑞) = |𝜃|

繪製出 JS 散度和EM 距離的曲線,如圖所示,可以看到,JS 散度在𝜃 = 0處不連續,其他位置導數均爲0,而EM 距離總能夠產生有效的導數信息,因此EM 距離相對於JS 散度更適合指導GAN 網絡的訓練。

在這裏插入圖片描述

WGAN-GP

考慮到幾乎不可能遍歷所有的聯合分佈𝛾去計算距離‖𝑥 − 𝑦‖的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],
因此直接計算生成網絡分佈𝑝𝑔與真實數據分佈𝑝𝑟的𝑊(𝑝𝑟, 𝑝𝑔)距離是不現實的,WGAN 作
者基於Kantorovich-Rubinstein 對偶性將直接求𝑊(𝑝𝑟, 𝑝𝑔)轉換爲求:
在這裏插入圖片描述
其中𝑠𝑢𝑝{∙}表示集合的上確界,||𝑓||𝐿 ≤ 𝐾表示函數𝑓: 𝑅 → 𝑅滿足K-階Lipschitz 連續性,即
滿足
在這裏插入圖片描述
於是,我們使用判別網絡𝐷𝜃(𝒙)參數化𝑓(𝒙)函數,在𝐷𝜃滿足1 階-Lipschitz 約束的條件下,即𝐾 = 1,此時:

在這裏插入圖片描述
因此求解𝑊(𝑝𝑟, 𝑝𝑔)的問題可以轉化爲:

在這裏插入圖片描述
這就是判別器D 的優化目標。判別網絡函數𝐷𝜃(𝒙)需要滿足1 階-Lipschitz 約束:∇𝒙̂𝐷(𝒙̂) ≤ 𝐼
在 WGAN-GP 論文中,作者提出採用增加梯度懲罰項(Gradient Penalty)方法來迫使判別網絡滿足1 階-Lipschitz 函數約束,同時作者發現將梯度值約束在1 周圍時工程效果更好,因此梯度懲罰項定義爲: GP ≜ 𝔼𝒙̂∼𝑃𝒙̂ [(‖𝛻𝒙̂𝐷(𝒙̂)‖2 − 1)2]

因此 WGAN 的判別器D 的訓練目標爲:
在這裏插入圖片描述
其中𝒙̂來自於𝒙𝑟與𝒙𝑟的線性差值:

𝑥̂ = 𝑡𝒙𝑟 + (1 − 𝑡)𝒙𝑓 , 𝑡 ∈ [0,1]

判別器 D 的目標是最小化上述的誤差ℒ(𝐺, 𝐷),即迫使生成器G 的分佈𝑝𝑔與真實分佈𝑝𝑟之間EM 距離𝔼𝒙𝑟∼𝑝𝑟[𝐷(𝒙𝑟)]−𝔼𝒙𝑓∼𝑝𝑔 [𝐷(𝒙𝑓)]項儘可能大,‖𝛻𝒙̂𝐷(𝒙̂)‖2逼近於1。

WGAN 的生成器G 的訓練目標爲
在這裏插入圖片描述

即使得生成器的分佈𝑝𝑔與真實分佈𝑝𝑟之間的EM 距離越小越好。考慮到𝔼𝒙𝑟∼𝑝𝑟[𝐷(𝒙𝑟)]一項與生成器無關,因此生成器的訓練目標簡寫爲:
在這裏插入圖片描述
從實現來看,判別網絡D 的輸出不需要添加Sigmoid 激活函數,這是因爲原始版本的判別器的功能是作爲二分類網絡,添加Sigmoid 函數獲得類別的概率;而WGAN 中判別器作爲EM 距離的度量網絡,其目標是衡量生成網絡的分佈𝑝𝑔和真實分佈𝑝𝑟之間的EM 距離,屬於實數空間,因此不需要添加Sigmoid 激活函數。在誤差函數計算時,WGAN 也沒有log 函數存在。在訓練WGAN 時,WGAN 作者推薦使用RMSProp 或SGD 等不帶動量的優化器

WGAN 從理論層面發現了原始GAN 容易出現訓練不穩定的原因,並給出了一種新的距離度量標準和工程實現解決方案,取得了較好的效果。WGAN 還在一定程度上緩解了模式崩塌的問題,使用WGAN 的模型不容易出現模式崩塌的現象。需要注意的是,WGAN一般並不能提升模型的生成效果,僅僅是保證了模型訓練的穩定性當然,保證模型能夠穩定地訓練也是取得良好效果的前提。如圖所示,原始版本的DCGAN 在不使用BN 層等設定時出現了訓練不穩定的現象,在同樣設定下,使用WGAN 來訓練判別器可以避免此現象,如圖所示。

不帶BN層的DCGAN生成器效果
在這裏插入圖片描述
不帶BN層的WGAN生成器效果
在這裏插入圖片描述

WGAN實戰代碼和資源

鏈接:https://pan.baidu.com/s/1sncma9kCQ5CzyqqMpGgOkA
提取碼:d693
複製這段內容後打開百度網盤手機App,操作更方便哦


參考書籍: TensorFlow 深度學習 — 龍龍老師

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章