TensorFlow 2.0深度學習算法實戰---第13章 生成對抗網絡

我不能創造的事物,我就還沒有完全理解它。−理查德·費曼

生成對抗網絡(Generative Adversarial Network,簡稱 GAN)發明之前,變分自編碼器被認爲是理論完備,實現簡單,使用神經網絡訓練起來很穩定,生成的圖片逼近度也較高,但是人眼還是可以很輕易地分辨出真實圖片與機器生成的圖片。

2014 年,Université de Montréal 大學 Yoshua Bengio(2019 年圖靈獎獲得者)的學生 Ian Goodfellow 提出了生成對抗網絡 GAN,從而開闢了深度學習最炙手可熱的研究方向之一。從 2014 年到 2019 年,GAN 的研究穩步推進,研究捷報頻傳,最新的 GAN 算法在圖片生成上的效果甚至達到了肉眼難辨的程度,着實令人振奮。由於 GAN 的發明,IanGoodfellow 榮獲 GAN 之父稱號,並獲得 2017 年麻省理工科技評論頒發的 35 Innovators Under 35 獎項。圖 13.1 展示了從 2014 年到 2018 年,GAN 模型取得了圖書生成的效果,可以看到不管是圖片大小,還是圖片逼真度,都有了巨大的提升。
在這裏插入圖片描述
接下來,我們將從生活中博弈學習的實例出發,一步步引出 GAN 算法的設計思想和模型結構。

13.1 博弈學習實例

我們用一個漫畫家的成長軌跡來形象介紹生成對抗網絡的思想。考慮一對雙胞胎兄弟,分別稱爲老二 G 和老大 D,G 學習如何繪製漫畫,D 學習如何鑑賞畫作。還在娃娃時代的兩兄弟,尚且只學會瞭如何使用畫筆和紙張,G 繪製了一張不明所以的畫作,如圖13.2(a)所示,由於此時 D 鑑別能力不高,覺得 G 的作品還行,但是人物主體不夠鮮明。在D 的指引和鼓勵下,G 開始嘗試學習如何繪製主體輪廓和使用簡單的色彩搭配。
在這裏插入圖片描述
一年後,G 提升了繪畫的基本功,D 也通過分析名作和初學者 G 的作品,初步掌握了鑑別作品的能力。此時 D 覺得 G 的作品人物主體有了,如圖 13.2(b),但是色彩的運用還不夠成熟。數年後,G 的繪畫基本功已經很紮實了,可以輕鬆繪製出主體鮮明、顏色搭配合適和逼真度較高的畫作,如圖13.2©,但是 D 同樣通過觀察 G 和其它名作的差別,提升了畫作鑑別能力,覺得 G 的畫作技藝已經趨於成熟,但是對生活的觀察尚且不夠,作品沒有傳達神情且部分細節不夠完美。又過了數年,G 的繪畫功力達到了爐火純青的地步,繪製的作品細節完美、風格迥異、惟妙惟肖,宛如大師級水準,如圖 13.2(d),即便此時的D 鑑別功力也相當出色,亦很難將 G 和其他大師級的作品區分開來。

上述畫家的成長曆程其實是一個生活中普遍存在的學習過程,通過雙方的博弈學習,相互提高,最終達到一個平衡點。GAN 網絡借鑑了博弈學習的思想,分別設立了兩個子網絡:負責生成樣本的生成器 G 和負責鑑別真僞的鑑別器 D。類比到畫家的例子,生成器 G就是老二,鑑別器 D 就是老大。鑑別器 D 通過觀察真實的樣本和生成器 G 產生的樣本之間的區別,學會如何鑑別真假,其中真實的樣本爲真,生成器 G 產生的樣本爲假。而生成器 G 同樣也在學習,它希望產生的樣本能夠獲得鑑別器 D 的認可,即在鑑別器 D 中鑑別爲真,因此生成器 G 通過優化自身的參數,嘗試使得自己產生的樣本在鑑別器 D 中判別爲真。生成器 G 和鑑別器 D 相互博弈,共同提升,直至達到平衡點。此時生成器 G 生成的樣本非常逼真,使得鑑別器 D 真假難分。

在原始的 GAN 論文中,Ian Goodfellow 使用了另一個形象的比喻來介紹 GAN 模型:生成器網絡 G 的功能就是產生一系列非常逼真的假鈔試圖欺騙鑑別器 D,而鑑別器 D 通過學習真鈔和生成器 G 生成的假鈔來掌握鈔票的鑑別方法。這兩個網絡在相互博弈的過程中間同步提升,直到生成器 G 產生的假鈔非常的逼真,連鑑別器 D 都真假難辨。

這種博弈學習的思想使得 GAN 的網絡結構和訓練過程與之前的網絡模型略有不同,下面我們來詳細介紹 GAN 的網絡結構和算法原理。

13.2 GAN 原理

現在我們來正式介紹生成對抗網絡的網絡結構和訓練方法。

13.2.1 網絡結構

生成對抗網絡包含了兩個子網絡:生成網絡(Generator,簡稱 G)和判別網絡(Discriminator,簡稱 D),其中生成網絡 G 負責學習樣本的真實分佈,判別網絡 D 負責將生成網絡採樣的樣本與真實樣本區分開來。

生成網絡G(𝒛) 生成網絡 G 和自編碼器的 Decoder 功能類似,從先驗分佈𝑝𝒛(∙)中採樣隱藏變量𝒛~𝑝𝒛(∙),通過生成網絡 G 參數化的pg(xz)p_{g}(x | z)分佈,獲得生成樣本xpg(xz)\boldsymbol{x} \sim p_{g}(\boldsymbol{x} | \mathbf{z})如圖13.3 所示。其中隱藏變量𝒛的先驗分佈𝑝𝒛(∙)可以假設爲某中已知的分佈,比如多元均勻分佈zz \sim Uniform (-1,1)。
在這裏插入圖片描述
pg(xz)p_{g}(x | z)可以用深度神經網絡來參數化,如下圖 13.4 所示,從均勻分佈𝑝𝒛(∙)中採樣出隱藏變量𝒛,經過多層轉置卷積層網絡參數化的pg(xz)p_{g}(x | z)分佈中採樣出樣本xfx_{f}。從輸入輸出層面來看,生成器 G 的功能是將隱向量𝒛通過神經網絡轉換爲樣本向量xfx_{f},下標𝑓代表假樣本(Fake samples)。
在這裏插入圖片描述
判別網絡D(𝒙) 判別網絡和普通的二分類網絡功能類似,它接受輸入樣本𝒙的數據集,包含了採樣自真實數據分佈pr()p_{r}(\cdot)的樣本xrpr()x_{r} \sim p_{r}(\cdot),也包含了採樣自生成網絡的假樣本xfpg(xz)\boldsymbol{x}_{f} \sim p_{g}(\boldsymbol{x} | \mathbf{z})xrx_{r}xfx_{f}共同組成了判別網絡的訓練數據集。判別網絡輸出爲𝒙屬於真實樣本的概率𝑃(𝒙爲真|𝒙),我們把所有真實樣本xrx_{r}的標籤標註爲真(1),所有生成網絡產生的樣本xfx_{f}標註爲假(0),通過最小化判別網絡 D 的預測值與標籤之間的誤差來優化判別網絡參數,如圖 13.5 所示。
在這裏插入圖片描述

13.2.2 網絡訓練

GAN 博弈學習的思想體現在在它的訓練方式上,由於生成器 G 和判別器 D 的優化目標不一樣,不能和之前的網絡模型的訓練一樣,只採用一個損失函數。下面我們來分別介紹如何訓練生成器 G 和判別器 D。

對於判別網絡 D,它的目標是能夠很好地分辨出真樣本xrx_{r}與假樣本xfx_{f}。以圖片生成爲例,它的目標是最小化圖片的預測值和真實值之間的交叉熵損失函數
minθL=CE(Dθ(xr),yr,Dθ(xf),yf)\min _{\theta} \mathcal{L}=\operatorname{CE}\left(D_{\theta}\left(\boldsymbol{x}_{r}\right), y_{r}, D_{\theta}\left(\boldsymbol{x}_{f}\right), y_{f}\right)
其中Dθ(xr)D_{\theta}\left(\boldsymbol{x}_{r}\right)代表真實樣本xrx_{r}在判別網絡DθD_{\theta}的輸出,𝜃爲判別網絡的參數集,Dθ(xf)D_{\theta}\left(\boldsymbol{x}_{f}\right)爲生成樣本xfx_{f}在判別網絡的輸出,yry_{r}xrx_{r}的標籤,由於真實樣本標註爲真,故yry_{r} = 1,yfy_{f}爲生成樣本的xfx_{f}的標籤,由於生成樣本標註爲假,故yfy_{f} = 0。CE函數代表交叉熵損失函數CrossEntropy。二分類問題的交叉熵損失函數定義爲:
L=xrpr()logDθ(xr)xfpg()log(1Dθ(xf))\mathcal{L}=-\sum_{x_{r} \sim p_{r}()} \log D_{\theta}\left(x_{r}\right)-\sum_{x_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(x_{f}\right)\right)
因此判別網絡 D 的優化目標是:
θ=argminθxrpr()logDθ(xr)xfpg()log(1Dθ(xf))\theta^{*}=\underset{\theta}{\operatorname{argmin}}-\sum_{x_{r} \sim p_{r}(\cdot)} \log D_{\theta}\left(x_{r}\right)-\sum_{x_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(x_{f}\right)\right)
minθL\min _{\theta} \mathcal{L}問題轉換爲maxθL\max _{\theta}-\mathcal{L},並寫成期望形式:
θ=argmaxθExrpr()logDθ(xr)+Exfpg()log(1Dθ(xf))\theta^{*}=\underset{\theta}{\operatorname{argmax}} \mathbb{E}_{x_{r} \sim p_{r}(\cdot)} \log D_{\theta}\left(x_{r}\right)+\mathbb{E}_{x_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(x_{f}\right)\right)
對於生成網絡G(𝒛),我們希望xf=G(z)x_{f}=G(z)能夠很好地騙過判別網絡 D,假樣本xfx_{f}在判別網絡的輸出越接近真實的標籤越好。也就是說,在訓練生成網絡時,希望判別網絡的輸出𝐷(𝐺(𝒛))越逼近 1 越好,最小化𝐷(𝐺(𝒛))與 1 之間的交叉熵損失函數:
minϕL=CE(D(Gϕ(z)),1)=logD(Gϕ(z))\min _{\phi} \mathcal{L}=C E\left(D\left(G_{\phi}(\mathbf{z})\right), 1\right)=-\log D\left(G_{\phi}(\mathbf{z})\right)
minϕL\min _{\phi} \mathcal{L}問題轉換成maxϕL\max _{\phi}-\mathcal{L},並寫成期望形式:
ϕ=argmaxϕEzpz()logD(Gϕ(z))\phi^{*}=\underset{\phi}{\operatorname{argmax}} \mathbb{E}_{\mathbf{z} \sim p_{z}(\cdot)} \log D\left(G_{\phi}(\mathbf{z})\right)
再次等價轉化爲:
ϕ=argminϕL=Ezpz()log[1D(Gϕ(z))]\phi^{*}=\underset{\phi}{\operatorname{argmin}} \mathcal{L}=\mathbb{E}_{\mathbf{z} \sim p_{z}(\cdot)} \log \left[1-D\left(G_{\phi}(\mathbf{z})\right)\right]
其中𝜙爲生成網絡 G 的參數集,可以利用梯度下降算法來優化參數𝜙。

13.2.3 統一目標函數

我們把判別網絡的目標和生成網絡的目標合併,寫成min − max博弈形式:
minϕmaxθL(D,G)=Exrpr()logDθ(xr)+Exfpg()log(1Dθ(xf))\min _{\phi} \max _{\theta} \mathcal{L}(D, G)=\mathbb{E}_{x_{r} \sim p_{r}(\cdot)} \log D_{\theta}\left(\boldsymbol{x}_{r}\right)+\mathbb{E}_{\boldsymbol{x}_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(\boldsymbol{x}_{f}\right)\right)
=Expr(c)logDθ(x)+Ezpz()log(1Dθ(Gϕ(z)))=\mathbb{E}_{\boldsymbol{x} \sim p_{\boldsymbol{r}}(\boldsymbol{c})} \log D_{\boldsymbol{\theta}}(\boldsymbol{x})+\mathbb{E}_{\mathbf{z} \sim p_{\boldsymbol{z}}(\cdot)} \log \left(1-D_{\boldsymbol{\theta}}\left(G_{\boldsymbol{\phi}}(\mathbf{z})\right)\right)

算法流程如下:
在這裏插入圖片描述

13.3 DCGAN 實戰

本節我們來完成一個二次元動漫頭像圖片生成實戰,參考 DCGAN 的網絡結構,其中判別器 D 利用普通卷積層實現,生成器 G 利用轉置卷積層實現,如圖 13.6 所示。
在這裏插入圖片描述

13.3.1 動漫圖片數據集

這裏使用的是一組二次元動漫頭像的數據集,共 51223 張圖片,無標註信息,圖片主體已裁剪、對齊並統一縮放到96 × 96大小,部分樣片如圖 13.7 所示。

數據集下載地址:https://github.com/chenyuntc/pytorch-book/tree/master/chapter07-AnimeGAN
在這裏插入圖片描述
對於自定義的數據集,需要自行完成數據的加載和預處理工作,我們這裏聚焦在 GAN算法本身,後續自定義數據集一章會詳細介紹如何加載自己的數據集,這裏直接通過預編寫好的make_anime_dataset函數返回已經處理好的數據集對象。代碼如下:

img_path = glob.glob(r'C:\Users\z390\Downloads\faces\*.jpg')
# 構建數據集對象,返回數據集 Dataset 類和圖片大小
dataset, img_shape, _ = make_anime_dataset(img_path, batch_size,resize=64)

其中 dataset 對象就是 tf.data.Dataset 類實例,已經完成了隨機打散、預處理和批量化等操作,可以直接迭代獲得樣本批,img_shape 是預處理後的圖片大小。

13.3.2 生成器

生成網絡 G 由 5 個轉置卷積層單元堆疊而成,實現特徵圖高寬的層層放大,特徵圖通道數的層層減少。首先將長度爲 100 的隱藏向量𝒛通過 Reshape 操作調整爲[𝑏, 1,1,100]的 4維張量,並依序通過轉置卷積層,放大高寬維度,減少通道數維度,最後得到高寬爲 64,通道數爲 3 的彩色圖片。每個卷積層中間插入 BN 層來提高訓練穩定性,卷積層選擇不使用偏置向量。生成器的類代碼實現如下:

class Genetator(keras.Model):
    #生成器網絡類
    def __init__(self):
        super(Genetator,self).__init__()
        filter=64
        #轉置卷積層1,輸出channel爲filter*8,核大小4,步長1,不使用padding,不使用偏置
        self.conv1=layers.Conv2DTranspose(filter*8,4,1,'valid',use_bias=False)
        self.bn1=layers.BatchNormalization()
        #轉置卷積層2
        self.conv2=layers.Conv2DTranspose(filter*4,4,2,'same',use_bias=False)
        self.bn2=layers.BatchNormalization()
        #轉置卷積層3
        self.conv3=layers.Conv2DTranspose(filter*2,4,2,'same',use_bias=False)
        self.bn3=layers.BatchNormalization()
        #轉置卷積層4
        self.conv4=layers.Conv2DTranspose(filter*1,4,2,'same',use_bias=False)
        self.bn4=layers.BatchNormalization()
        #轉置卷積層5
        self.conv5=layers.Conv2DTranspose(3,4,2,'same',use_bias=False)

生成網絡 G 的前向傳播過程實現如下:

    def call(self,inputs,training=None):
        x=inputs#[z,100]
        # reshape成4D張量,方便後續轉換卷積運算:(b,1,1,100)
        x=tf.reshape(x,(x.shape[0],1,1,x.shape[1]))
        x=tf.nn.relu(x)#激活函數
        #轉置卷積-BN-激活函數:(b,4,4,512)
        x=tf.nn.relu(self.bn1(self.conv1(x),training=training))
        # 轉置卷積-BN-激活函數:(b,8,8,256)
        x=tf.nn.relu(self.bn2(self.conv2(x),training=training))
        # 轉置卷積-BN-激活函數:(b,16,16,128)
        x=tf.nn.relu(self.bn3(self.conv3(x), training=training))
        # 轉置卷積-BN-激活函數:(b,32,32,64)
        x=tf.nn.relu(self.bn4(self.conv4(x), training=training))
        # 轉置卷積-BN-激活函數:(b,64,64,3)
        x = self.conv5(x)
        x=tf.tanh(x)#輸出x範圍[-1~1],與預處理一致

        return x

生成網絡的輸出大小爲[𝑏, 64,64,3]的圖片張量,數值範圍爲[−1~1]。

13.3.3 判別器

判別網絡 D 與普通的分類網絡相同,接受大小爲[𝑏, 64,64,3]的圖片張量,連續通過 5個卷積層實現特徵的層層提取,卷積層最終輸出大小爲[𝑏, 2,2,1024],再通過池化層GlobalAveragePooling2D將特徵大小轉換爲[𝑏, 1024],最後通過一個全連接層獲得二分類任務的概率。判別網絡 D 類的代碼實現如下:

class Discriminator(keras.Model):
    #判別器類
    def __init__(self):
        super(Discriminator,self).__init__()
        filter=64
        #卷積層1
        self.conv1=layers.Conv2D(filter,4,2,'valid',use_bias=False)
        self.bn1=layers.BatchNormalization()
        # 卷積層2
        self.conv2=layers.Conv2D(filter*2, 4, 2, 'valid', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        # 卷積層3
        self.conv3 = layers.Conv2D(filter*4, 4, 2, 'valid', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        # 卷積層4
        self.conv4 = layers.Conv2D(filter*8, 3, 1, 'valid', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        # 卷積層5
        self.conv5= layers.Conv2D(filter*16, 3, 1, 'valid', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        #全局池化層
        self.pool=layers.GlobalAveragePooling2D()
        #特徵打平層
        self.flatten=layers.Flatten()
        #2分類全連接層
        self.fc=layers.Dense(1)

判別器 D 的前向計算過程實現如下:

    def call(self,inputs,training=None):
        #卷積-BN-激活函數:(b,31,31,64)
        x=tf.nn.leaky_relu(self.bn1(self.conv1(inputs),training=training))
        # 卷積-BN-激活函數:(b,14,14,128)
        x = tf.nn.leaky_relu(self.bn2(self.conv2(inputs), training=training))
        # 卷積-BN-激活函數:(b,6,6,256)
        x = tf.nn.leaky_relu(self.bn3(self.conv3(inputs), training=training))
        # 卷積-BN-激活函數:(b,4,4,512)
        x = tf.nn.leaky_relu(self.bn4(self.conv4(inputs), training=training))
        # 卷積-BN-激活函數:(b,2,2,1024)
        x = tf.nn.leaky_relu(self.bn5(self.conv5(inputs), training=training))
        # 卷積-BN-激活函數:(b,1024)
        x = self.pool(x)
        #打平
        x=self.flatten(x)
        #輸出,[b,1024]=>[b,1]
        logits=self.fc(x)
        return logits

判別器的輸出大小爲[𝑏, 1],類內部沒有使用 Sigmoid 激活函數,通過 Sigmoid 激活函數後可獲得𝑏個樣本屬於真實樣本的概率。

13.3.4 訓練與可視化

判別網絡 根據上述公式,判別網絡的訓練目標是最大化ℒ(𝐷, 𝐺)函數,使得真實樣本預測爲真的概率接近於 1,生成樣本預測爲真的概率接近於 0。我們將判斷器的誤差函數實現在 d_loss_fn 函數中,將所有真實樣本標註爲 1,所有生成樣本標註爲 0,並通過最小化對應的交叉熵損失函數來實現最大化ℒ(𝐷,𝐺)函數。d_loss_fn 函數實現如下:

def d_loss_fn(generator,discriminator,batch_z,batch_x,is_training):
    #計算判別器的誤差函數
    #採樣生成圖片
    fake_image=generator(batch_z,is_training)
    #判定生成圖片
    d_fake_logits=discriminator(fake_image,is_training)
    #判定真實圖片
    d_real_logits=discriminator(batch_x,is_training)
    #真實圖片與1之間的誤差
    d_loss_real=celoss_ones(d_real_logits)
    #生成圖片與0之間的誤差
    d_loss_fake=celoss_zeros(d_fake_logits)
    #合併誤差
    loss=d_loss_fake+d_loss_real
    
    return loss

其中 celoss_ones 函數計算當前預測概率與標籤 1 之間的交叉熵損失,代碼如下:

def celoss_ones(logits):
    #計算屬於與標籤爲1的交叉熵
    y=tf.ones_like(logits)
    loss=keras.losses.binary_crossentropy(y,logits,from_logits=True)
    return tf.reduce_mean(loss)

celoss_zeros 函數計算當前預測概率與標籤 0 之間的交叉熵損失,代碼如下:

def celoss_zeros(logits):
    #計算屬於與便籤爲0的交叉熵
    y=tf.zeros_like(logits)
    loss=keras.losses.binary_crossentropy(y,logits,from_logits=True)
    return tf.reduce_men(loss)

生成網絡 的訓練目標是最小化ℒ(𝐷, 𝐺)目標函數,由於真實樣本與生成器無關,因此誤差函數只需要考慮最小化Ezpz()log(1Dθ(Gϕ(z)))\mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\cdot)} \log \left(1-D_{\theta}\left(G_{\phi}(\mathbf{z})\right)\right)項即可。可以通過將生成的樣本標註爲 1,最小化此時的交叉熵誤差。需要注意的是,在反向傳播誤差的過程中,判別器也參與了計算圖的構建,但是此階段只需要更新生成器網絡參數,而不更新判別器的網絡參數。

生成器的誤差函數代碼如下:

def g_loss_fn(generator,discriminator,batch_z,is_training):
    # 採樣生成圖片
    fake_image=generator(batch_z,is_training)
    #在訓練生成網絡時,需要迫使生成圖片判定爲真
    d_fake_logits=discriminator(fake_image,is_training)
    #計算生成圖片與1之間的誤差
    loss=celoss_ones(d_fake_logits)

    return loss

網絡訓練 在每個 Epoch,首先從先驗分佈pz()p_{z}(\cdot)中隨機採樣隱藏向量,從真實數據集中隨機採樣真實圖片,通過生成器和判別器計算判別器網絡的損失,並優化判別器網絡參數𝜃。在訓練生成器時,需要藉助於判別器來計算誤差,但是隻計算生成器的梯度信息並更新𝜙。這裏設定判別器訓練𝑘 = 5次後,生成器訓練一次。

首先創建生成網絡和判別網絡,並分別創建對應的優化器。代碼如下:

z_dim = 100  # 隱藏向量z的長度
learning_rate = 0.0002
is_training = True
generator=Genetator()#創建生成器
generator.build(input_shape=(4,z_dim))
discriminator=Discriminator()#創建判別器
discriminator.build(input_shape=(4,64,64,3))
#分別爲生成器和判別器創建優化器
g_optimizer=keras.optimizers.Adam(learning_rate=learning_rate,beta_l=0.5)
d_optimizer=keras.optimizers.Adam(learning_rate=learning_rate,beta_l=0.5)

主訓練部分代碼實現如下:

epochs = 3000000  # 訓練步數
batch_size = 64  # batch size
for epoch in range(epochs):
    #1.訓練判別器
    for _ in range(5):
        #採樣隱藏向量
        batch_z=tf.random.normal([batch_size,z_dim])#[64,100]
        batch_x=next(db_iter)#採樣真實圖片
        #判別器前向計算
        with tf.GradientTape() as tape:
            d_loss=d_loss_fn(generator,discriminator,batch_z,batch_x,is_training)
        grads=tape.gradient(d_loss,discriminator.training_variables)
        d_optimizer.apply_gradients(zip(grads,discriminator.trainable_variables))

    #2.訓練生成器
    # 採樣隱藏向量
    batch_z=tf.random.normal([batch_size,z_dim])
    batch_x=next(db_iter)
    #生成器前向計算
    with tf.GradientTape() as tape:
        g_loss=g_loss_fn(generator,discriminator,batch_z,is_training)
    grads=tape.gradient(g_loss,generator.trainable_varibales)
    g_optimizer.apply_gradients(zip(grads,generator.trainable_variables))  

每間隔 100 個 Epoch,進行一次圖片生成測試。通過從先驗分佈中隨機採樣隱向量,送入生成器獲得生成圖片,並保存爲文件。

如圖 13.8 所示,展示了 DCGAN 模型在訓練過程中保存的生成圖片樣例,可以觀察到,大部分圖片主體明確,色彩逼真,圖片多樣性較豐富,圖片效果較爲貼近數據集中真實的圖片。同時也能發現仍有少量生成圖片損壞,無法通過人眼辨識圖片主體。
在這裏插入圖片描述

13.4 GAN 變種

在原始的 GAN 論文中,Ian Goodfellow 從理論層面分析了 GAN 網絡的收斂性,並且在多個經典圖片數據集上測試了圖片生成的效果,如圖 13.9 所示,其中圖 13.9 (a)爲MNIST 數據,圖 13.9 (b)爲 Toronto Face 數據集,圖 13.9 ©、圖 13.9 (d)爲 CIFAR10 數據集。
在這裏插入圖片描述
可以看到,原始 GAN 模型在圖片生成效果上並不突出,和 VAE 差別不明顯,此時並沒有展現出它強大的分佈逼近能力。但是由於 GAN 在理論方面較新穎,實現方面也有很多可以改進的地方,大大地激發了學術界的研究興趣。在接下來的數年裏,GAN 的研究如火如荼的進行,並且也取得了實質性的進展。接下來我們將介紹幾個意義比較重大的 GAN變種。

13.4.1 DCGAN

最初始的 GAN 網絡主要基於全連接層實現生成器 G 和判別器 D 網絡,由於圖片的維度較高,網絡參數量巨大,訓練的效果並不優秀。DCGAN提出了使用轉置卷積層實現的生成網絡,普通卷積層來實現的判別網絡,大大地降低了網絡參數量,同時圖片的生成效果也大幅提升,展現了 GAN 模型在圖片生成效果上超越 VAE 模型的潛質。此外,DCGAN 作者還提出了一系列經驗性的 GAN 網絡訓練技巧,這些技巧在 WGAN 提出之前被證實有益於網絡的穩定訓練。前面我們已經使用 DCGAN 模型完成了二次元動漫頭像的圖片生成實戰。

13.4.2 InfoGAN

InfoGAN 嘗試使用無監督的方式去學習輸入𝒙的可解釋隱向量𝒛的表示方法(Interpretable Representation),即希望隱向量𝒛能夠對應到數據的語義特徵。比如對於MNIST 手寫數字圖片,我們可以認爲數字的類別、字體大小和書寫風格等是圖片的隱藏變量,希望模型能夠學習到這些分離的(Disentangled)可解釋特徵表示方法,從而可以通過人爲控制隱變量來生成指定內容的樣本。對於 CelebA 名人照片數據集,希望模型可以把髮型、眼鏡佩戴情況、面部表情等特徵分隔開,從而生成指定形態的人臉圖片。

分離的可解釋特徵有什麼好處呢?它可以讓神經網絡的可解釋性更強,比如𝒛包含了一些分離的可解釋特徵,那麼我們可以通過僅僅改變這一個位置上面的特徵來獲得不同語義的生成數據,如圖 13.10 所示,通過將“戴眼鏡男士”與“不戴眼鏡男士”的隱向量相減,並與“不戴眼鏡女士”的隱向量相加,可以生成“戴眼鏡女士”的生成圖片。
在這裏插入圖片描述

13.4.3 CycleGAN

CycleGAN是華人朱儁彥提出的無監督方式進行圖片風格相互轉換的算法,由於算法清晰簡單,實驗效果完成的較好,這項工作受到了很多的讚譽。CycleGAN 基本的假設是,如果由圖片 A 轉換到圖片 B,再從圖片 B 轉換到A′,那麼A′應該和 A 是同一張圖片。因此除了設立標準的 GAN 損失項外,CycleGAN 還增設了循環一致性損失(CycleConsistency Loss),來保證A′儘可能與 A 逼近。CycleGAN 圖片的轉換效果如圖 13.11 所示。
在這裏插入圖片描述

13.4.4 WGAN

GAN 的訓練問題一直被詬病,很容易出現訓練不收斂和模式崩塌的現象。WGAN從理論層面分析了原始的 GAN 使用 JS 散度存在的缺陷,並提出了可以使用 Wasserstein 距離來解決這個問題。在 WGAN-GP中,作者提出了通過添加梯度懲罰項,從工程層面很好的實現了 WGAN 算法,並且實驗性證實了 WGAN 訓練穩定的優點。

13.4.5 Equal GAN

從 GAN 的誕生至 2017 年底,GAN Zoo 已經收集超過了 214 種 GAN 網絡變種。這些 GAN 的變種或多或少地提出了一些創新,然而 Google Brain 的幾位研究員在論文中提供了另一個觀點:沒有證據表明我們測試的 GAN 變種算法一直持續地比最初始的 GAN要好。論文中對這些 GAN 變種進行了相對公平、全面的比較,在有足夠計算資源的情況下,發現幾乎所有的 GAN 變種都能達到相似的性能(FID 分數)。這項工作提醒業界是否這些 GAN 變種具有本質上的創新。

13.4.6 Self-Attention GAN

Attention 機制在自然語言處理(NLP)中間已經用得非常廣泛了,Self-Attention GAN(SAGAN) 借鑑了 Attention 機制,提出了基於自注意力機制的 GAN 變種。SAGAN 把圖片的逼真度指標:Inception score,從最好的 36.8 提升到 52.52,Frechet Inception distance,從 27.62 降到 18.65。從圖片生成效果上來看,SAGAN 取得的突破是十分顯著的,同時也啓發業界對自注意力機制的關注。
在這裏插入圖片描述

13.4.7 BigGAN

在 SAGAN 的基礎上,BigGAN嘗試將 GAN 的訓練擴展到大規模上去,利用正交正則化等技巧保證訓練過程的穩定性。BigGAN 的意義在於啓發人們,GAN 網絡的訓練同樣可以從大數據、大算力等方面受益。BigGAN 圖片生成效果達到了前所未有的高度:Inception score 記錄提升到 166.5(提高了 52.52);Frechet Inception Distance 下降到 7.4,降低了 18.65,如圖 13.13 所示,圖片的分辨率可達512 × 512,圖片細節極其逼真。
在這裏插入圖片描述

13.5 納什均衡

現在我們從理論層面進行分析,通過博弈學習的訓練方式,生成器 G判別器 D 分別會達到什麼平衡狀態。具體地,我們將探索以下兩個問題:

❑ 固定 G,D 會收斂到什麼最優狀態𝐷∗?
❑ 在 D 達到最優狀態𝐷∗後,G 會收斂到什麼狀態?

首先我們通過xrpr()\boldsymbol{x}_{r} \sim p_{r}(\cdot)一維正態分佈的例子給出一個直觀的解釋。如圖 13.14 所示,黑色虛線曲線代表了真實數據的分佈pr()p_{r}(\cdot),爲某正態分佈𝒩(𝜇, 𝜎2),綠色實線代表了生成網絡學習到的分佈xfpg()\boldsymbol{x}_{f} \sim p_{g}(\cdot),藍色虛線代表了判別器的決策邊界曲線,圖 13.14 (a)、(b)、( c )、(d)分別代表了生成網絡的學習軌跡。

在初始狀態,如圖 13.14(a)所示,pg()p_{g}(\cdot)分佈與pr()p_{r}(\cdot)差異較大,判別器可以很輕鬆地學習到明確的決策邊界,即圖 13.14(a)中的藍色虛線,將來自pg()p_{g}(\cdot)的採樣點判定爲 0,pr()p_{r}(\cdot)中的採樣點判定爲 1。隨着生成網絡的分佈pg()p_{g}(\cdot)越來越逼近真實分佈pr()p_{r}(\cdot),判別器越來越困難將真假樣本區分開,如圖 13.14(b)( c )所示。最後,生成網絡學習到的分佈pg()p_{g}(\cdot)= pr()p_{r}(\cdot)時,此時從生成網絡中採樣的樣本非常逼真,判別器無法區分,即判定爲真假樣本的概率均等,如圖 13.14(d)所示。

這個例子直觀地解釋了 GAN 網絡的訓練過程。
在這裏插入圖片描述

13.5.1 判別器狀態

現在來推導第一個問題。回顧 GAN 的損失函數:
L(G,D)=xpr(x)log(D(x))dx+zpz(z)log(1D(g(z)))dz=xpr(x)log(D(x))+pg(x)log(1D(x))dx\begin{aligned} \mathcal{L}(G, D) &=\int_{x} p_{r}(\boldsymbol{x}) \log (D(\boldsymbol{x})) d \boldsymbol{x}+\int_{\boldsymbol{z}} p_{\boldsymbol{z}}(\mathbf{z}) \log (1-D(g(\boldsymbol{z}))) d \boldsymbol{z} \\ &=\int_{\boldsymbol{x}} p_{r}(\boldsymbol{x}) \log (D(\boldsymbol{x}))+p_{g}(\boldsymbol{x}) \log (1-D(\boldsymbol{x})) d \boldsymbol{x} \end{aligned}

對於判別器 D,優化的目標是最大化ℒ(𝐺,𝐷)函數,需要找出函數:
fθ=pr(x)log(D(x))+pg(x)log(1D(x))f_{\theta}=p_{r}(x) \log (D(x))+p_{g}(x) \log (1-D(x))

的最大值,其中𝜃爲判別器𝐷的網絡參數。

我們來考慮fθf_{\theta}更通用的函數的最大值情況:
f(x)=Alogx+Blog(1x)f(x)=A \log x+B \log (1-x)

要求得函數𝑓(𝑥)的最大值。考慮𝑓(𝑥)的導數:
df(x)dx=A1ln101xB1ln1011x=1ln10(AxB1x)=1ln10A(A+B)xx(1x)\begin{aligned} &\frac{\mathrm{d} f(x)}{\mathrm{d} x}=A \frac{1}{\ln 10} \frac{1}{x}-B \frac{1}{\ln 10} \frac{1}{1-x}\\ &=\frac{1}{\ln 10}\left(\frac{A}{x}-\frac{B}{1-x}\right)\\ &=\frac{1}{\ln 10} \frac{A-(A+B) x}{x(1-x)} \end{aligned}

df(x)dx=0\frac{\mathrm{d} f(x)}{\mathrm{d} x}=0,我們可以求得𝑓(𝑥)函數的極值點:
x=AA+Bx=\frac{A}{A+B}

因此,可以得知,fθf_{\theta}函數的極值點同樣爲:
Dθ=pr(x)pr(x)+pg(x)D_{\theta}=\frac{p_{r}(\boldsymbol{x})}{p_{r}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}

也就是說,判別器網絡DθD_{\theta}處於DθD_{\theta^{*}}狀態時,fθf_{\theta}函數取得最大值,ℒ(𝐺, 𝐷)函數也取得最大值。

現在回到最大化ℒ(𝐺,𝐷)的問題,ℒ(𝐺,𝐷)的最大值點在:
D=AA+B=pr(x)pr(x)+pg(x)D^{*}=\frac{A}{A+B}=\frac{p_{r}(x)}{p_{r}(x)+p_{g}(x)}
時取得,此時也是𝐷𝜃的最優狀態𝐷∗。

13.5.2 生成器狀態

在推導第二個問題之前,我們先介紹一下與 KL 散度類似的另一個分佈距離度量標準:JS 散度,它定義爲 KL 散度的組合:
DKL(pq)=xp(x)logp(x)q(x)dxDJS(pq)=12DKL(pp+q2)+12DKL(qp+q2)\begin{array}{c} D_{K L}(p \| q)=\int_{x} p(x) \log \frac{p(x)}{q(x)} d x \\ D_{J S}(p \| q)=\frac{1}{2} D_{K L}\left(p \| \frac{p+q}{2}\right)+\frac{1}{2} D_{K L}\left(q \| \frac{p+q}{2}\right) \end{array}

JS 散度克服了 KL 散度不對稱的缺陷。

當 D 達到最優狀態𝐷∗時,我們來考慮此時prp_{r}pgp_{g}的 JS 散度:
DJS(prpg)=12DKL(prpr+pg2)+12DKL(pgpr+pg2)D_{J S}\left(p_{r} \| p_{g}\right)=\frac{1}{2} D_{K L}\left(p_{r} \| \frac{p_{r}+p_{g}}{2}\right)+\frac{1}{2} D_{K L}\left(p_{g} \| \frac{p_{r}+p_{g}}{2}\right)
根據 KL 散度的定義展開爲:
DJS(prpg)=12(log2+xpr(x)logpr(x)pr+pg(x)dx)+12(log2+xpg(x)logpg(x)pr+pg(x)dx)\begin{aligned} D_{J S}\left(p_{r}|| p_{g}\right) &=\frac{1}{2}\left(\log 2+\int_{x} p_{r}(x) \log \frac{p_{r}(x)}{p_{r}+p_{g}(x)} d x\right) \\ &+\frac{1}{2}\left(\log 2+\int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{r}+p_{g}(x)} d x\right) \end{aligned}
合併常數項可得:
DJS(prpg)=12(log2+log2)D_{J S}\left(p_{r} \| p_{g}\right)=\frac{1}{2}(\log 2+\log 2)
+12(xpr(x)logpr(x)pr+pg(x)dx+xpg(x)logpg(x)pr+pg(x)dx)+\frac{1}{2}\left(\int_{x} p_{r}(x) \log \frac{p_{r}(x)}{p_{r}+p_{g}(x)} d x+\int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{r}+p_{g}(x)} d x\right)
即:
DJS(prpg)=12(log4)+12(xpr(x)logpr(x)pr+pg(x)dx+xpg(x)logpg(x)pr+pg(x)dx)\begin{array}{c} D_{J S}\left(p_{r}|| p_{g}\right)=\frac{1}{2}(\log 4) \\ +\frac{1}{2}\left(\int_{x} p_{r}(x) \log \frac{p_{r}(x)}{p_{r}+p_{g}(x)} d x+\int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{r}+p_{g}(x)} d x\right) \end{array}

考慮在判別網絡到達𝐷∗時,此時的損失函數爲:
L(G,D)=xpr(x)log(D(x))+pg(x)log(1D(x))dx=xpr(x)logpr(x)pr+pg(x)dx+xpg(x)logpg(x)pr+pg(x)dx\begin{array}{l} \mathcal{L}\left(G, D^{*}\right)=\int_{x} p_{r}(\boldsymbol{x}) \log \left(D^{*}(\boldsymbol{x})\right)+p_{g}(\boldsymbol{x}) \log \left(1-D^{*}(\boldsymbol{x})\right) d \boldsymbol{x} \\ =\int_{\boldsymbol{x}} p_{r}(\boldsymbol{x}) \log \frac{p_{r}(\boldsymbol{x})}{p_{r}+p_{g}(\boldsymbol{x})} d \boldsymbol{x}+\int_{\boldsymbol{x}} p_{g}(\boldsymbol{x}) \log \frac{p_{g}(\boldsymbol{x})}{p_{r}+p_{g}(\boldsymbol{x})} d \boldsymbol{x} \end{array}
因此在判別網絡到達𝐷∗時,DJS(prpg)D_{J S}\left(p_{r} \| p_{g}\right)L(G,D)\mathcal{L}\left(G, D^{*}\right)滿足關係:
DJS(prpg)=12(log4+L(G,D))D_{J S}\left(p_{r} \| p_{g}\right)=\frac{1}{2}\left(\log 4+\mathcal{L}\left(G, D^{*}\right)\right)
即:
L(G,D)=2DJS(prpg)2log2\mathcal{L}\left(G, D^{*}\right)=2 D_{J S}\left(p_{r} \| p_{g}\right)-2 \log 2
對於生成網絡 G 而言,訓練目標是minGL(G,D)\min _{G} \mathcal{L}(G, D) ,考慮到 JS 散度具有性質:
DJS(prpg)0D_{J S}\left(p_{r} \| p_{g}\right) \geq 0
因此L(G,D)\mathcal{L}\left(G, D^{*}\right)取得最小值僅在DJS(prpg)=0D_{J S}\left(p_{r} \| p_{g}\right)=0時(此時pgp_{g}=prp_{r}),L(G,D)\mathcal{L}\left(G, D^{*}\right)取得最小值:
L(G,D)=2log2\mathcal{L}\left(G^{*}, D^{*}\right)=-2 \log 2
此時生成網絡𝐺∗的狀態是:
pg=prp_{g}=p_{r}
即𝐺∗的學到的分佈pgp_{g}與真實分佈prp_{r}一致,網絡達到平衡點,此時:
D=pr(x)pr(x)+pg(x)=0.5D^{*}=\frac{p_{r}(\boldsymbol{x})}{p_{r}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}=0.5

13.5.3 納什均衡點

通過上面的推導,我們可以總結出生成網絡 G 最終將收斂到真實分佈,即:
pg=prp_{g}=p_{r}
此時生成的樣本與真實樣本來自同一分佈,真假難辨,在判別器中均有相同的概率判定爲
真或假,即
D()=0.5D(\cdot)=0.5
此時損失函數爲
L(G,D)=2log2\mathcal{L}\left(G^{*}, D^{*}\right)=-2 \log 2

13.6 GAN 訓練難題

儘管從理論層面分析了 GAN 網絡能夠學習到數據的真實分佈,但是在工程實現中,常常出現 GAN 網絡訓練困難的問題,主要體現在 GAN 模型對超參數較爲敏感,需要精心挑選能使模型工作的超參數設定,同時也容易出現模式崩塌現象。

13.6.1 超參數敏感

超參數敏感是指網絡的結構設定、學習率、初始化狀態等超參數對網絡的訓練過程影響較大,微量的超參數調整將可能導致網絡的訓練結果截然不同。如圖 13.15 所示,圖(a)爲 GAN 模型良好訓練得到的生成樣本,圖(b)中的網絡由於沒有采用 Batch Normalization層等設置,導致 GAN 網絡訓練不穩定,無法收斂,生成的樣本與真實樣本差距非常大。

爲了能較好地訓練 GAN 網絡,DCGAN 論文作者提出了不使用 Pooling 層、多使用Batch Normalization 層、不使用全連接層、生成網絡中激活函數應使用 ReLU、最後一層使用tanh激活函數、判別網絡激活函數應使用 LeakyLeLU 等一系列經驗性的訓練技巧。但是這些技巧僅能在一定程度上避免出現訓練不穩定的現象,並沒有從理論層面解釋爲什麼會出現訓練困難、以及如果解決訓練不穩定的問題。
在這裏插入圖片描述

13.6.2 模式崩塌

模式崩塌(Mode Collapse)是指模型生成的樣本單一,多樣性很差的現象。由於判別器只能鑑別單個樣本是否採樣自真實分佈,並沒有對樣本多樣性進行顯式約束,導致生成模型可能傾向於生成真實分佈的部分區間中的少量高質量樣本,以此來在判別器中獲得較高的概率值,而不會學習到全部的真實分佈。模式崩塌現象在 GAN 中比較常見,如圖 13.16所示,在訓練過程中,通過可視化生成網絡的樣本可以觀察到,生成的圖片種類非常單一,生成網絡總是傾向於生成某種單一風格的樣本圖片,以此騙過判別器。
在這裏插入圖片描述
另一個直觀地理解模式崩塌的例子如圖 13.17 所示,第一行爲未出現模式崩塌現象的生成網絡的訓練過程,最後一列爲真實分佈,即 2D 高斯混合模型;第二行爲出現模式崩塌現象的生成網絡的訓練過程,最後一列爲真實分佈。可以看到真實的分佈由 8 個高斯模型混合而成,出現模式崩塌後,生成網絡總是傾向於逼近真實分佈的某個狹窄區間,如圖13.17 第 2 行前 6 列所示,從此區間採樣的樣本往往能夠在判別器中較大概率判斷爲真實樣本,從而騙過判別器。但是這種現象並不是我們希望看到的,我們希望生成網絡能夠逼近真實的分佈,而不是真實分佈中的某部分
在這裏插入圖片描述
那麼怎麼解決 GAN 訓練的難題,讓 GAN 可以像普通的神經網絡一樣訓練較爲穩定呢?WGAN 模型給出了一種解決方案。

13.7 WGAN 原理

WGAN 算法從理論層面分析了 GAN 訓練不穩定的原因,並提出了有效的解決方法。那麼是什麼原因導致了 GAN 訓練如此不穩定呢?WGAN 提出是因爲 JS 散度在不重疊的分佈𝑝和𝑞上的梯度曲面是恆定爲 0 的。如圖 13.19 所示,當分佈𝑝和𝑞不重疊時,JS 散度的梯度值始終爲 0,從而導致此時 GAN 的訓練出現梯度彌散現象,參數長時間得不到更新,網絡無法收斂。
在這裏插入圖片描述
接下來我們將詳細闡述 JS 散度的缺陷以及怎麼解決此缺陷。

13.7.1 JS 散度的缺陷

爲了避免過多的理論推導,我們這裏通過一個簡單的分佈實例來解釋 JS 散度的缺陷。考慮完全不重疊(𝜃 ≠ 0)的兩個分佈𝑝和𝑞,其中分佈𝑝爲:
(x,y)p,x=0,yU(0,1)\forall(x, y) \in p, x=0, y \sim \mathrm{U}(0,1)
分佈𝑞爲:
(x,y)q,x=θ,yU(0,1)\forall(x, y) \in q, x=\theta, y \sim \mathrm{U}(0,1)
其中𝜃 ∈ 𝑅,當𝜃 = 0時,分佈𝑝和𝑞重疊,兩者相等;當𝜃 ≠ 0時,分佈𝑝和𝑞不重疊。
在這裏插入圖片描述
我們來分析上述分佈𝑝和𝑞之間的 JS 散度隨𝜃的變化情況。根據 KL 散度與 JS 散度的定義,計算𝜃 = 0時的 JS 散度𝐷𝐽𝑆(𝑝||𝑞):
DKL(pq)=x=0,yU(0,1)1log10=+D_{K L}(p \| q)=\sum_{x=0, y \sim U(0,1)} 1 \cdot \log \frac{1}{0}=+\infty
DKL(qp)=x=θ,yU(0,1)1log10=+D_{K L}(q \| p)=\sum_{x=\theta, y \sim \mathrm{U}(0,1)} 1 \cdot \log \frac{1}{0}=+\infty
DJS(pq)=12(x=0,yU(0,1)1log11/2+x=0,yU(0,1)1log11/2)=log2D_{J S}(p \| q)=\frac{1}{2}\left(\sum_{x=0, y \sim U(0,1)} 1 \cdot \log \frac{1}{1 / 2}+\sum_{x=0, y \sim U(0,1)} 1 \cdot \log \frac{1}{1 / 2}\right)=\log 2
當𝜃 = 0時,兩個分佈完全重疊,此時的 JS 散度和 KL 散度都取得最小值,即 0:
DKL(pq)=DKL(qp)=DJS(pq)=0D_{K L}(p \| q)=D_{K L}(q \| p)=D_{J S}(p \| q)=0
從上面的推導,我們可以得到DJS(pq)\left|D_{J S}(p \| q)\right.隨𝜃的變化趨勢:
DJS(pq)={log2θ00θ=0D_{J S}(p \| q)=\left\{\begin{array}{cl} \log 2 & \theta \neq 0 \\ 0 & \theta=0 \end{array}\right.
也就是說,當兩個分佈完全不重疊時,無論分佈之間的距離遠近,JS 散度爲恆定值log 2,此時 JS 散度將無法產生有效的梯度信息;當兩個分佈出現重疊時,JS 散度纔會平滑變動,產生有效梯度信息;當完全重合後,JS 散度取得最小值 0。如圖 13.19 中所示,紅色的曲線分割兩個正態分佈,由於兩個分佈沒有重疊,生成樣本位置處的梯度值始終爲 0,無法更新生成網絡的參數,從而出現網絡訓練困難的現象。
在這裏插入圖片描述
因此,JS 散度在分佈𝑝和𝑞不重疊時是無法平滑地衡量分佈之間的距離,從而導致此位置上無法產生有效梯度信息,出現 GAN 訓練不穩定的情況。要解決此問題,需要使用一種更好的分佈距離衡量標準,使得它即使在分佈𝑝和𝑞不重疊時,也能平滑反映分佈之間的真實距離變化。

13.7.2 EM 距離

WGAN 論文發現了 JS 散度導致 GAN 訓練不穩定的問題,並引入了一種新的分佈距離度量方法:Wasserstein 距離,也叫推土機距離(Earth-Mover Distance,簡稱 EM 距離),它表示了從一個分佈變換到另一個分佈的最小代價,定義爲:
W(p,q)=infγΠ(p,q)E(x,y)γ[xy]W(p, q)=\inf _{\gamma \sim \Pi(p, q)} \mathbb{E}_{(x, y) \sim \gamma}[\|x-y\|]
其中∏(𝑝, 𝑞)是分佈𝑝和𝑞組合起來的所有可能的聯合分佈的集合,對於每個可能的聯合分佈𝛾 ∼ ∏(𝑝, 𝑞),計算距離‖𝑥 − 𝑦‖的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],其中(𝑥, 𝑦)採樣自聯合分佈𝛾。不同的聯合分佈𝛾有不同的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],這些期望中的下確界即定義爲分佈𝑝和𝑞的Wasserstein 距離。其中inf{∙}表示集合的下確界,例如{𝑥|1 < 𝑥 < 3, 𝑥 ∈ 𝑅}的下確界爲 1。

繼續考慮圖 13.18 中的例子,我們直接給出分佈𝑝和𝑞之間的 EM 距離的表達式:
W(p,q)=θW(p, q)=|\theta|
繪製出 JS 散度和 EM 距離的曲線,如圖 13.20 所示,可以看到,JS 散度在𝜃 = 0處不連續,其他位置導數均爲 0,而 EM 距離總能夠產生有效的導數信息,因此 EM 距離相對於JS 散度更適合指導 GAN 網絡的訓練。
在這裏插入圖片描述

13.7.3 WGAN-GP

考慮到幾乎不可能遍歷所有的聯合分佈𝛾去計算距離‖𝑥 − 𝑦‖的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],因此直接計算生成網絡分佈pgp_{g}與真實數據分佈prp_{r}W(pr,pg)W\left(p_{r}, p_{g}\right)距離是不現實的,WGAN 作者基於 Kantorovich-Rubinstein 對偶性將直接求W(pr,pg)W\left(p_{r}, p_{g}\right)轉換爲求:
W(pr,pg)=1KsupfLKExpr[f(x)]Expg[f(x)]W\left(p_{r}, p_{g}\right)=\frac{1}{K} \underset{\|f\|_{L} \leq K}{\sup } \mathbb{E}_{x \sim p_{r}}[f(x)]-\mathbb{E}_{x \sim p_{g}}[f(x)]
其中𝑠𝑢𝑝{∙}表示集合的上確界,||𝑓||𝐿 ≤ 𝐾表示函數𝑓: 𝑅 → 𝑅滿足 K-階 Lipschitz 連續性,即滿足
f(x1)f(x2)Kx1x2\left|f\left(x_{1}\right)-f\left(x_{2}\right)\right| \leq K \cdot\left|x_{1}-x_{2}\right|
於是,我們使用判別網絡Dθ(x)D_{\theta}(x)參數化𝑓(𝒙)函數,在DθD_{\theta}滿足 1 階-Lipschitz 約束的條件下,即𝐾 = 1,此時:
W(pr,pg)=supDθL1Expr[Dθ(x)]Expg[Dθ(x)]W\left(p_{r}, p_{g}\right)=\sup _{\left\|D_{\theta}\right\|_{L} \leq 1} \mathbb{E}_{x \sim p_{r}}\left[D_{\theta}(\boldsymbol{x})\right]-\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[D_{\theta}(\boldsymbol{x})\right]
因此求解W(pr,pg)W\left(p_{r}, p_{g}\right)的問題可以轉化爲:
maxθExpr[Dθ(x)]Expg[Dθ(x)]\max _{\theta} \mathbb{E}_{x \sim p_{r}}\left[D_{\theta}(\boldsymbol{x})\right]-\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[D_{\theta}(\boldsymbol{x})\right]
這就是判別器 D的優化目標。判別網絡函數Dθ(x)D_{\theta}(x)需要滿足 1 階-Lipschitz 約束:
x^D(x^)I\nabla_{\widehat{x}} D(\hat{\boldsymbol{x}}) \leq I
在 WGAN-GP 論文中,作者提出採用增加梯度懲罰項(Gradient Penalty)方法來迫使判別網絡滿足 1 階-Lipschitz 函數約束,同時作者發現將梯度值約束在 1 周圍時工程效果更好,因此梯度懲罰項定義爲:
GPEx^Px^[(x^D(x^)21)2]\mathrm{GP} \triangleq \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\left\|\nabla_{\hat{x}} D(\hat{x})\right\|_{2}-1\right)^{2}\right]
因此 WGAN 的判別器 D 的訓練目標爲:
maxθL(G,D)=Exrpr[D(xr)]Exfpg[D(xf)]EMLRλEx^Px~[(x^D(x^)21)2]GPZ,UH\max _{\theta} \mathcal{L}(G, D)=\underbrace{\mathbb{E}_{x_{r} \sim p_{r}}\left[D\left(\boldsymbol{x}_{r}\right)\right]-\mathbb{E}_{x_{f} \sim p_{g}}\left[D\left(\boldsymbol{x}_{f}\right)\right]}_{E M \mathscr{L}_{\mathbb{R}}} \underbrace{-\lambda \mathbb{E}_{\hat{\boldsymbol{x}}_{\mathcal{P}_{\tilde{\boldsymbol{x}}}}}\left[\left(\left\|\nabla_{\hat{\boldsymbol{x}}} D(\hat{\boldsymbol{x}})\right\|_{2}-1\right)^{2}\right]}_{G P \mathbb{Z}, \mathbb{U} \mathscr{H}}
其中𝒙̂來自於𝒙𝑟與𝒙𝑟的線性差值:
x^=txr+(1t)xf,t[0,1]\hat{x}=t \boldsymbol{x}_{r}+(1-t) \boldsymbol{x}_{f}, t \in[0,1]
判別器 D 的目標是最小化上述的誤差ℒ(𝐺,𝐷),即迫使生成器 G 的分佈pgp_{g}與真實分佈prp_{r}之間 EM 距離Exrpr[D(xr)]Exfpg[D(xf)]\mathbb{E}_{x_{r} \sim p_{r}}\left[D\left(\boldsymbol{x}_{r}\right)\right]-\mathbb{E}_{\boldsymbol{x}_{f} \sim p_{g}}\left[D\left(\boldsymbol{x}_{f}\right)\right]項儘可能大,x^D(x^)2\left\|\nabla_{\widehat{x}} D(\widehat{\boldsymbol{x}})\right\|_{2}逼近於 1。

WGAN 的生成器 G 的訓練目標爲:
minϕL(G,D)=Exrpr[D(xr)]Exfpg[D(xf)]EMEE\min _{\phi} \mathcal{L}(G, D)=\underbrace{\mathbb{E}_{x_{r} \sim p_{r}}\left[D\left(\boldsymbol{x}_{r}\right)\right]-\mathbb{E}_{\boldsymbol{x}_{f} \sim p_{g}}\left[D\left(\boldsymbol{x}_{f}\right)\right]}_{E M \mathbb{E}_{\mathbb{E}}}

即使得生成器的分佈pgp_{g}與真實分佈prp_{r}之間的 EM 距離越小越好。考慮到Exrpr[D(xr)]\mathbb{E}_{x_{r} \sim p_{r}}\left[D\left(\boldsymbol{x}_{r}\right)\right]一項與生成器無關,因此生成器的訓練目標簡寫爲:
minϕL(G,D)=Exfpg[D(xf)]=Ezpz()[D(G(z))]\begin{array}{c} \min _{\phi} \mathcal{L}(G, D)=-\mathbb{E}_{x_{f} \sim p_{g}}\left[D\left(x_{f}\right)\right] \\ =-\mathbb{E}_{z \sim p_{z}(\cdot)}[D(G(z))] \end{array}
從實現來看,判別網絡 D 的輸出不需要添加 Sigmoid 激活函數,這是因爲原始版本的判別器的功能是作爲二分類網絡,添加 Sigmoid 函數獲得類別的概率;而 WGAN 中判別器作爲 EM 距離的度量網絡,其目標是衡量生成網絡的分佈𝑝𝑔和真實分佈𝑝𝑟之間的 EM 距離,屬於實數空間,因此不需要添加 Sigmoid 激活函數。在誤差函數計算時,WGAN 也沒有 log 函數存在。在訓練 WGAN 時,WGAN 作者推薦使用 RMSProp 或 SGD 等不帶動量的優化器。

WGAN 從理論層面發現了原始 GAN 容易出現訓練不穩定的原因,並給出了一種新的距離度量標準和工程實現解決方案,取得了較好的效果。WGAN 還在一定程度上緩解了模式崩塌的問題,使用 WGAN 的模型不容易出現模式崩塌的現象。需要注意的是,WGAN一般並不能提升模型的生成效果,僅僅是保證了模型訓練的穩定性。當然,保證模型能夠穩定地訓練也是取得良好效果的前提。如圖13.21 所示,原始版本的 DCGAN 在不使用BN 層等設定時出現了訓練不穩定的現象,在同樣設定下,使用 WGAN 來訓練判別器可以避免此現象,如圖 13.22 所示。
在這裏插入圖片描述
在這裏插入圖片描述

13.8 WGAN-GP 實戰

WGAN-GP 模型可以在原來 GAN 代碼實現的基礎上僅做少量修改。WGAN-GP 模型的判別器 D 的輸出不再是樣本類別的概率,輸出不需要加 Sigmoid 激活函數。同時添加梯度懲罰項,實現如下:

def gradient_penalty(discriminator, batch_x, fake_image):
    # 梯度懲罰項計算函數
    batchsz = batch_x.shape[0]
    # 每個樣本均隨機採樣 t,用於插值
    t = tf.random.uniform([batchsz, 1, 1, 1])
    # 自動擴展爲 x 的形狀,[b, 1, 1, 1] => [b, h, w, c]
    t = tf.broadcast_to(t, batch_x.shape)
    # 在真假圖片之間做線性插值
    interplate = t * batch_x + (1 - t) * fake_image
   # 在梯度環境中計算 D 對插值樣本的梯度
   with tf.GradientTape() as tape:
       tape.watch([interplate]) # 加入梯度觀察列表
       d_interplote_logits = discriminator(interplate)
   grads = tape.gradient(d_interplote_logits, interplate)
   # 計算每個樣本的梯度的範數:[b, h, w, c] => [b, -1]
   grads = tf.reshape(grads, [grads.shape[0], -1])
   gp = tf.norm(grads, axis=1) #[b]
   # 計算梯度懲罰項
   gp = tf.reduce_mean( (gp-1.)**2 )
   return gp

WGAN 判別器的損失函數計算與 GAN 不一樣,WGAN 是直接最大化真實樣本的輸出值,最小化生成樣本的輸出值,並沒有交叉熵計算的過程。代碼實現如下:

def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # 計算 D 的損失函數
    fake_image = generator(batch_z, is_training) # 假樣本
    d_fake_logits = discriminator(fake_image, is_training) # 假樣本的輸出
    d_real_logits = discriminator(batch_x, is_training) # 真樣本的輸出
    # 計算梯度懲罰項
    gp = gradient_penalty(discriminator, batch_x, fake_image)
    # WGAN-GP D 損失函數的定義,這裏並不是計算交叉熵,而是直接最大化正樣本的輸出
    # 最小化假樣本的輸出和梯度懲罰項
    loss = tf.reduce_mean(d_fake_logits) - tf.reduce_mean(d_real_logits) + 10. * gp
    return loss, gp

WGAN 生成器 G 的損失函數是隻需要最大化生成樣本在判別器 D 的輸出值即可,同樣沒有交叉熵的計算步驟。代碼實現如下:

def g_loss_fn(generator, discriminator, batch_z, is_training):
   # 生成器的損失函數
   fake_image = generator(batch_z, is_training)
   d_fake_logits = discriminator(fake_image, is_training)
   # WGAN-GP G 損失函數,最大化假樣本的輸出值
   loss = - tf.reduce_mean(d_fake_logits)
   return loss

WGAN 的主訓練邏輯基本相同,與原始的 GAN 相比,判別器 D 的作用是作爲一個EM 距離的計量器存在,因此判別器越準確,對生成器越有利,可以在訓練一個 Step 時訓練判別器 D 多次,訓練生成器 G 一次,從而獲得較爲精準的 EM 距離估計。

參考文獻

1.https://github.com/chenyuntc/pytorch-book
2.那麼多GAN哪個好?谷歌大腦潑來冷水:都和原版差不多 | 論文

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章