生成對抗網絡(GAN)之 Basic Theory 學習筆記
前言:最近學習了李宏毅生成對抗網絡篇(2018年)的視頻(視頻地址:李宏毅對抗生成網絡(GAN)國語教程(2018)),因爲截止今天(3.23),2020版還未講到生成對抗網絡,因此選擇18年。本次學習筆記主要爲Basic Theory部分,主要講解GAN的數學原理。
GAN又稱生成對抗網絡,是由Ian Goodfellow等人在2014年提出的一種訓練策略,論文地址https://arxiv.org/abs/1406.2661,也稱爲對抗學習(Adversarial Learning),其主要由兩個部分組成,一個是生成器(或者說是採樣器)(Generator),另一個是判別器(Discriminator)。生成器主要負責從低維度簡單的分佈(例如正態分佈、均勻分佈等)隨機採樣,並映射到一個複雜高維度空間中,判別器的目標則是根據生成器生成的分佈和真實世界中採樣的樣本進行區分,儘可能的給真實世界的樣本高分,給生成器生成的假數據低分。一般來說,判別器可以當做一個0-1分類器。
1、GAN基本實現原理
GAN的生成原理如下圖所示:
結合圖,給出一個形式化的描述。給定一個低維度隨機分佈(例如Normal Distribution)作爲Noise,從中隨機採樣一個樣本 ,生成器首先通過計算將其生成高維度空間中的樣本,例如如果是圖像生成任務,則生成器的輸出 表示一張圖像。然後由判別器獲取到生成器生成的圖像 以及從真實世界中隨機採樣的真實圖像 ,判別器分別爲其進行打分 ,,目標則是使得 ,。生成器與判別器是一組相互制衡的Object,判別器儘可能的來區分哪些樣本是來自真實世界,哪些來自生成器製造的假數據;而生成器則是儘可能生成一些看起來很像真實世界的數據來迷惑判別器。
一個比較直觀的例子就是小偷與警察,這在許多博客中也有所提及。初始化時候,小偷的偷盜水平很差,警察的辦案能力也不高,但始終有一個條件就是警察的水平總會比小偷高一點,因此警察總會抓住一些水平很低的小偷。小偷爲了生存則不斷提高自己的偷盜水平和反偵察能力,警察也發現許多小偷的水平太高了來提升自己的辦案水平。如此進行下去。當然我們並不希望小偷的水平能夠達到警察都分辨不出來的底部,但是對於GAN來講,最終的結果則是我們更希望得到一個強大的生成器,因爲它能夠在判別器不斷的指導下來生成非常真實的東西,以至於連判別器都沒有辦法區分它們。
爲了能夠形象的描述,我們假設藍色曲線表示的是生成器生成的複雜的高維度分佈,綠色表示實際真實分佈,紅色則表示判別器的判別分佈。判別器目標就是在生成器生成的分佈部分給與低分,在真實分佈部分給與高分,而隨着不斷地迭代,生成樣本會在真實樣本附近來回的震盪,其會經過所有可能會使得判別器給分非常高但並非是真是樣本分佈的部分,最終理想狀態是右下角的圖,兩者分佈完全吻合。
2、GAN數學基礎
通過對GAN的理解,我們會產生疑問,生成器與判別器如何去衡量它們的好壞?首先我們從統計學角度分析。
我們知道,生成器輸入的是從低維度分佈中隨機採樣的噪聲,也就是先驗概率分佈,例如我們可以選擇正態分佈。而對於圖像等這一類數據往往是位於高緯度空間,且真正有價值、可讀的圖像只是其中一小部分,如下圖所示。綠色的部分就是生成器生成的高維度空間的分佈,最左側淺藍色則是潛在未知的真實分佈,我們更希望能夠讓這兩個分佈距離越小,即尋找一個最優的判別器以滿足 。但事實上,我們無法直接去計算兩者。
2.1、最大似然估計與KL散度
通過分析,GAN可以被認爲是衡量生成的分佈與真實分佈的距離,我們希望這個距離儘可能的小。在統計學中,我們是已知先驗分佈 ,從中隨機採樣一組樣本 ,我們可以根據這些採樣來對高維度分佈進行最大似然估計,即有一組參數
而上式可以再加上一個無關項 ,也就是說:
因此說,如果能夠找到一個參數 使得生成器生成的樣本與真實樣本的分佈KL散度最小,這組參數就是我們所學習的目標。但是事實上KL散度是不對稱的,並不能直接作爲GAN學習的目標。
2.2、GAN目標函數與JS散度
在上面我們提到GAN的目標是儘可能的讓判別器分辨出真假數據,生成器則是儘可能欺騙判別器,在Ian Goodfellow的論文中,給出了比較清晰的訓練目標函數,如下所示:
這是一個被稱爲min max遊戲的任務,也符合GAN的訓練機制:先固定生成器 ,尋找當前最優的 能夠使得 ,其次固定判別器 ,尋找能夠使得 。而事實上,這個min max遊戲本質是最小化 和 的JS散度,推導如下圖所示:
首先我們固定生成器,求最大化的 ,此時 可以看做是隻與 有關的一元函數,我們用積分來描述期望:
這裏使用一次積分換元,將的積分換爲對的積分。因爲此時
令無關變量,,則 ,求導後得到極值點爲 ,即
代入到 後,通過簡單的變換,可以轉換爲JS散度的形式,如下圖所示:
所以說,min max遊戲本質上是最小化JS散度,即
用圖示來描述這個過程:假設有三個不同的生成器 ,其對應判別器生成的曲線爲藍色線條,而線條上的點到底邊軸的距離即爲 ,通過上面的公式推導,結合圖我們很快理解,首先是找到藍色線條的最大值點,這也每個生成器都對應一個最大值點;其次從所有生成器中尋找一個最小的最大值點。圖中對應的就是 。
2.3、GAN算法
GAN算法如圖所示:
首先分別從真實數據和噪聲中隨機採樣一組,,其中 爲batch_size。先固定生成器訓練判別器。得到每個生成的數據 ,其中 。根據生成數據和真實數據的採樣,訓練判別器,目標最大化 ,可採用梯度上升更新參數。其次固定判別器,訓練生成器,最小化,梯度下降法更新參數。
需要值得注意的是,雖然訓練判別器和生成器使用的目標函數是一樣,但一個是最大化,一個是最小化。另外在訓練生成器時,可以簡化目標函數爲第二項 ,因爲第一項相對於是一個常數項。使用這個作爲目標函數的GAN被命名爲MMGAN。還有,原文使用了另一個目標函數來優化生成器,如下圖:
也就是NSGAN,目標函數爲 其相比於MMGAN,其能夠保持梯度的方向是不變,但梯度值會比較大,方便計算。
3、特別說明
GAN在訓練過程中有幾點需要注意:
(1)GAN擬合速度比較慢,因爲對於高維度空間的分佈,GAN經常會生成一些肉眼無法理解的內容,而需要迭代非常多次才能達到比較穩定的範圍內;
(2)判別器需要保證儘可能的或接近收斂,而生成器不能訓練太強。可以想象,如果生成器訓練的很強,而判別器還沒有達到一個較好的結果,此時判別器就無法判別出誰是真實數據,誰是假數據,此時可能導致訓練終止了但生成器生成的樣本還是很糟糕。通常情況下,在每一次對抗過程中,判別器儘可能訓練多次到達接近收斂,生成器訓練1-3次。
(3)在實際對抗訓練中,有一種策略可以實現簡單的對抗模式。我們假設GAN用於噪聲分類,即判斷給定一個一組樣本中,哪些是positive,哪些是negative。可以先使用生成器生成一組序列,表示對應每個樣本是噪聲的概率,換句話說就是生成一組它認爲是噪聲negative的樣本集合。然後使用判別器去判別這些集合是不是噪聲。判別器會從真實數據(positive)採樣一部分數據,但給它們標記爲negative,而生成器生成的數據標記爲positive,這樣如果判別器無法判別出究竟是哪個是positive,哪個是negative時(也就是判別器性能下降了),生成器就能夠比較準確的找出噪聲。
(4)GAN還有諸多變形,本文只是GAN的基本模型。例如生成器可以使用卷積神經網絡來提取圖像特徵,使用RNN來提取文本類特徵等。判別器因爲本質上是一個二分類器,在深度學習中,只要保證有一定深度,也可以使用CNN、RNN等編碼器。另外,也有將自編碼器結合到GAN中,例如VAEGAN。