GAN生成對抗網絡:數學原理

1. 極大似然估計

GAN用到了極大似然估計(MLE),因此我們對MLE作簡單介紹。

MLE的目標是從樣本數據中估計出真實的數據分佈情況,所用的方法是最大化樣本數據在估計出的模型上的出現概率,也即選定使得樣本數據出現的概率最大的模型,作爲真實的數據分佈。

將真實模型用參數θ\theta表示,則在模型θ\theta下,樣本數據的出現概率(likelihood)是(1)i=1mpmodel(xi;θ)\prod_{i=1}^mp_{model}(x_i; \theta) \tag{1}

其中xix_i表示樣本中的第ii個數據。

最大化(1)式的概率,求得滿足條件的θ\theta
θ=argmaxθi=1mpmodel(xi;θ)=argmaxθi=1mlogpmodel(xi;θ)\begin{aligned} \theta^* & = \arg\max_\theta\prod_{i=1}^mp_{model}(x_i; \theta) \\ &= \arg\max_\theta\sum_{i=1}^m\log p_{model}(x_i; \theta) \\ \end{aligned}

還可以使用KL散度來代表MLE方法:
θ=argminθDKL(pdata(x)pmodel(x;θ)=argminθ{i=1mpdata(xi)logpdata(xi)i=1mpdata(xi)logpmodel(xi;θ)}=argminθi=1mpdata(xi)logpmodel(xi;θ)=argmaxθi=1mpdata(xi)logpmodel(xi;θ)\begin{aligned} \theta^*&=\arg\min_\theta D_{KL}(p_{data}(x) || p_{model}(x;\theta)\\ & = \arg\min_\theta\left\{ \sum_{i=1}^mp_{data}(x_i)\log p_{data}(x_i) - \sum_{i=1}^mp_{data}(x_i)\log p_{model}(x_i;\theta) \right\}\\ & = -\arg\min_\theta\sum_{i=1}^mp_{data}(x_i)\log p_{model}(x_i;\theta) \\ & = \arg\max_\theta\sum_{i=1}^mp_{data}(x_i)\log p_{model}(x_i;\theta) \end{aligned}

在實際上,我們無法得到數據的真實分佈pdatap_{data},但是可以從mm個數據的樣本中近似得到一個估計p^data\hat{p}_{data}

爲了便於理解KL散度,我們在下面對其進行簡要介紹。

2. 相對熵,KL散度

兩個概率分佈PPQQ的KL散度定義如下:
DKL(PQ)=iP(i)logP(i)Q(i)D_{KL}(P||Q)=\sum_iP(i)\log{\frac{P(i)}{Q(i)}}

性質
DKL(PQ)0D_{KL}(P||Q)\ge0

當且僅當P=QP=Q時,等號成立。(證明過程借用吉布斯不等式ipilogpiipilogqi\sum_ip_i\log p_i\ge\sum_ip_i\log q_i,證明吉布斯不等式會用到關係logxx1\log x \le x - 1

KL散度反映了兩個分佈PPQQ的相似情況,KL散度越小,兩個分佈越相似。

KL散度是不對稱的:
DKL(PQ)DKL(QP)D_{KL}(P||Q) \quad\neq D_{KL}(Q||P)

3. KL散度與交叉熵的關係

神經網絡中常常使用交叉熵作爲損失函數:
L=iyiloghiL = -\sum_i y_i\log h_i

其中yiy_i是實際的標籤值,hih_i是網絡的輸出值。

我們將yyhh的KL散度展開,得到:
DKL(yh)=iyilogyihi=iyilogyiiyiloghi=iyilogyi+L=Constant+L\begin{aligned} D_{KL}(y||h) & = \sum_iy_i\log{\frac{y_i}{h_i}}\\ & = \sum_iy_i\log y_i - \sum_iy_i\log h_i\\ & = \sum_iy_i\log y_i + L\\ &= Constant + L \end{aligned}

因此,最小化KL散度,等價於最小化損失函數LL。也即交叉熵損失函數反應的是網絡輸出結果和樣本實際標籤結果的KL散度的大小,交叉熵越小,KL散度也越小,網絡的輸出結果越接近實際值

4. JS散度

對於兩個分佈PPQQ,JS散度是:
DJS(PQ)=12DKL(PP+Q2)+12DKL(QP+Q2)D_{JS}(P||Q) = \frac{1}{2}D_{KL}(P||\frac{P+Q}{2}) + \frac{1}{2}D_{KL}(Q||\frac{P+Q}{2})

JS散度是對稱的,並且有界[0,log2][0, \log2]

5. GAN 框架

生成器,生成與訓練集數據相同分佈的樣本;判別器,檢查生成器生成的樣本是真的還是假的。
The generator is trained to fool the discriminator.
在這裏插入圖片描述

判別器的損失函數

判別器的損失函數爲:
(2)J(D)(θ(D),θ(G))=12ExpdatalogD(x)12Ezpmodellog(1D(G(z)))J^{(D)}(\theta^{(D)}, \theta^{(G)})= -\frac{1}{2}\mathbb{E}_{x\sim p_{data}}\log D(x) - \frac{1}{2}\mathbb{E}_{z\sim p_{model}}\log (1-D(G(z)))\tag{2}

上式其實就是一個交叉熵損失函數。GAN的判別器在訓練的過程中,數據集包含兩個部分,一部分是訓練集的樣本xx,對應的標籤y=1y=1,一部分是生成器生成的數據G(z)G(z),對應的標籤y=0y=0,因此判別器的訓練集可以看做X={x,G(z)},Y={1,0}X=\{x, G(z)\}, Y=\{1, 0\}

訓練集樣本是XX,標籤是YY,網絡輸出是HH,則交叉熵損失函數爲:
(3)J=1mi=1m{YilogHi(1Yi)log(1Hi)}J = \frac{1}{m} \sum_{i=1}^m\{-Y_i\log H_i - (1-Y_i)\log(1-H_i)\}\tag{3}

與式(2)作比較,前一項的logH\log H等價於式(2)中的logD(x)\log D(x),後一項的log(1Hi)\log(1-H_i)等價於式(2)中的log(1D(G(z)))\log(1-D(G(z)))。將xx看做包含了真實樣本和生成器生成的數據G(z)G(z)的新的訓練集,則判別器的損失函數可以重新寫作:
(4)J(D)(θ(D),θ(G))=12ExpdatalogD(x)12Expmodellog(1D(x))=12ipdata(xi)logD(xi)12ipmodel(xi)log(1D(xi))\begin{aligned} J^{(D)}(\theta^{(D)}, \theta^{(G)}) &= -\frac{1}{2}\mathbb{E}_{x\sim p_{data}}\log D(x) - \frac{1}{2}\mathbb{E}_{x\sim p_{model}}\log (1-D(x))\\ &= -\frac{1}{2} \sum_ip_{data}(x_i)\log D(x_i) -\frac{1}{2}\sum_i p_{model}(x_i) \log (1-D(x_i)) \end{aligned}\tag{4}

對上式關於D(x)D(x)求導,並令導數爲0,得到:
D(x)=pdata(x)pdata(x)+pmodel(x)D^*(x) = \frac{p_{data}(x)}{p_{data}(x)+p_{model}(x)}

生成器的損失函數

J(G)=J(D)J^{(G)}=-J^{(D)},則
J(G)(θ(D),θ(G))=12ExpdatalogD(x)+12Ezpmodellog(1D(G(z)))=Constant+12Ezpmodellog(1D(G(z)))\begin{aligned} J^{(G)}(\theta^{(D)}, \theta^{(G)}) &= \frac{1}{2}\mathbb{E}_{x\sim p_{data}}\log D(x) + \frac{1}{2}\mathbb{E}_{z\sim p_{model}}\log (1-D(G(z)))\\ & = Constant + \frac{1}{2}\mathbb{E}_{z\sim p_{model}}\log (1-D(G(z))) \end{aligned}

生成器沒有直接接受任何的訓練集數據,訓練集數據的信息是通過判別器學習後傳遞過來的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章