生成對抗網絡 GAN 的數學原理

摘要

本文從概率分佈及參數估計說起, 通過介紹極大似然估計, KL 散度, JS 散度, 詳細的介紹了 GAN 生成對抗網絡的數學原理.

相關

系列文章索引 :

https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

無論是黑白圖片或彩色圖片, 都是使用 0 ~ 255 的數值表示像素. 將所有的像素值除以 255 我們就可以將一張圖片轉化爲 0 ~ 1 的概率分佈, 而且這種轉化是可逆的, 乘以 255 就可以還原.

從某種意義上來講, GAN 圖片生成任務就是生成概率分佈. 因此, 我們有必要結合概率分佈來理解 GAN 生成對抗網絡的原理.

1. 概率分佈及參數估計

假設一個抽獎盒子裏有45個球, 其編號是 1~9 共9個數字. 每個編號的球擁有的數量是:

編號 1 2 3 4 5 6 7 8 9
數量 2 4 6 8 9 7 5 3 1
佔比 0.044 0.088 0.133 0.178 0.200 0.156 0.111 0.066 0.022

佔比是指用每個編號的數量除以所有編號的總數量, 在數理統計學中, 在不引起誤會的情況下, 這裏的佔比也可以被稱爲 概率/頻率.

使用向量 qq 表示上述的概率分佈 :

q=(2,4,6,8,9,7,5,3,1)/45  =(0.044,0.088,0.133,0.178,0.200,0.156,0.111,0.066,0.022) q = (2,4,6,8,9,7,5,3,1)/45 \;\\ =(0.044, 0.088, 0.133, 0.178, 0.200, 0.156, 0.111, 0.066, 0.022)

將上述分佈使用圖像繪製如下 :

在這裏插入圖片描述

現在我們希望構建一個函數 p=p(x;θ)p = p(x;\theta), 以 xx 爲編號作爲輸入數據, 輸出編號 xx 的概率. θ\theta 是參與構建這個函數的參數, 一經選定就不再變化.

假設上述概率分佈服從二次拋物線函數 :
p=p(x;θ)=θ1(x+θ2)2+θ3  x=(1,2,3,4,5,6,7,8,9) p=p(x;\theta)=\theta_1 (x+\theta_2)^2+\theta_3\\ \;\\ x = (1,2,3,4,5,6,7,8,9)

使用 L2 誤差作爲評價擬合效果的損失函數, 總誤差值爲 error (標量 e) :
e=i=19(piqi)2 e = \sum_{i=1}^{9}(p_i-q_i)^2
我們希望求得一個 θ\theta^*, 使得 ee 的值越小越好, 使用數學公式表達是這樣子的 :
θ=argminθ(e) \theta^* = \underset{\theta}{argmin}(e)
argmin 是 argument minimum 的縮寫.

如何求 θ\theta^* 不是本文的重點, 這是生成對抗網絡的任務. 爲了幫助理解, 取其中一個可能的數值作爲示例 :

θ=(θ1,θ2,θ3)=(0.01,5.0,0.2)  p=p(x;θ)=0.01(x5.0)2+0.2 \theta^* = (\theta_1,\theta_2,\theta_3)=(-0.01,-5.0,0.2)\\ \;\\ p=p(x;\theta)=-0.01 (x-5.0)^2+0.2
繪製函數圖像如下 :

在這裏插入圖片描述

在生成對抗網絡中, 本例的估計函數 p(x;θ)p(x;\theta) 相當於生成模型 (generator), 損失函數相當於鑑別模型 (discriminator).

2. 極大似然估計

在上例中, 我們很幸運的知道了所有可能的概率分佈, 並讓求解最優的概率分佈估計函數 p(x;θ)p(x;\theta) 成爲可能.

如果上例的抽獎盒子 (簡稱樣本) 中的 45 個球是從更大的抽獎池 (簡稱總體) 中選擇出來的, 而我們不知道抽獎池中所有球的數量及其編號. 那麼, 我們如何根據現有的 45 個球來估計抽獎池的概率分佈呢? 我們當然可以直接用上例求得的樣本估計函數來代表抽獎池的概率分佈, 但本例會介紹一種更常用的估計方法.

假設 p(x)=p(x;θ)p(x)=p(x;\theta) 是總體的概率分佈函數. 則編號 x=(x1,x2,x3, ,xn)x = (x_1,x_2,x_3,\cdots,x_n) 出現的概率爲 :
p=p(x1),p(x2),p(x3), ,p(xn) p = p(x_1),p(x_2),p(x_3),\cdots,p(x_n)
在本例中, n=9n = 9, 即共 9 個編號.

d=(d1,d2,d3,d3 ,dm)d=(d_1,d_2,d_3,d_3\cdots,d_m) 是所有的抽樣的編號. 在本例中, m=45m = 45, 即樣本中共有 45 個抽樣. 假設所有的樣本和抽樣都是獨立的, 則樣本出現的概率爲 :
ρ=p(d1)×p(d2)×p(d3)××p(dm)=i=1mp(di) \rho= p(d_1)\times p(d_2)\times p(d_3)\times\cdots\times p(d_m)=\prod_{i=1}^{m}p(d_i)
p(x)=p(x;θ)p(x)=p(x;\theta) 的函數結構是人爲按經驗選取的, 比如線性函數, 多元二次函數, 更復雜的非線性函數等, 一經選取則不再改變. 現在我們需要求解一個參數集 θ\theta^*, 使得 ρ\rho 的值越大越好. 即
θ=argmaxθ(ρ)=argmaxθi=1mp(di;θ) \theta^* = \underset{\theta}{argmax}(\rho)=\underset{\theta}{argmax}\prod_{i=1}^{m}p(d_i;\theta)
argmax 是 argument maximum 的縮寫.

通俗來講, 因爲樣本是實際已發生的事實, 在函數結構已確定的情況下, 我們需要儘量優化參數, 使得樣本的理論估計概率越大越好.

這裏有一個前提, 就是人爲選定的函數結構應當能夠有效評估樣本分佈. 反之, 如果使用線性函數去擬合正態概率分佈 (normal distribution), 則無論如何選擇參數都無法得到滿意的效果.

連乘運算不方便, 將之改爲求和運算. 由於 loglog 對數函數的單調性, 上面的式子等價於 :
θ=argmaxθ  logi=1mp(di;θ)=argmaxθi=1mlog  p(di;θ) \theta^* =\underset{\theta}{argmax}\;log\prod_{i=1}^{m}p(d_i;\theta)=\underset{\theta}{argmax}\sum_{i=1}^{m}log\;p(d_i;\theta)
設樣本分佈爲 q(x)q(x), 對於給定樣本, 這個分佈是已知的, 可以通過統計抽樣的計算得出. 將上式轉化成期望公式 :
θ=argmaxθi=1mlog  p(di;θ)=argmaxθi=1nq(xi)  log  p(xi;θ) \theta^* =\underset{\theta}{argmax}\sum_{i=1}^{m}log\;p(d_i;\theta) =\underset{\theta}{argmax}\sum_{i=1}^{n}q(x_i)\;log\;p(x_i;\theta)
注意上式中的兩個求和符號, mm 變成了 nn. 在大多數情況下, 編號數量會比抽樣數量少, 轉爲期望公式可以顯著減少計算量.

在一些教材中, 上式的寫法是:
θ=argmaxθ  Exq(x)log  p(x;θ)=argmaxθq(x)  log  p(x;θ)dx \theta^*=\underset{\theta}{argmax}\; E_{x-q(x)}log\;p(x;\theta) =\underset{\theta}{argmax}\int q(x)\;log\;p(x;\theta)dx
Exq(x)E_{x-q(x)} 表示按 q(x)q(x) 的分佈對 xx 求期望. 因爲積分表達式比較簡潔, 書寫方便, 下文開始將主要使用積分表達式.

以上就是極大似然估計(Maximum Likelihood Estimation) 的理論和推導過程. 和上例的參數估計方法相比, 極大似然估計的方法, 因爲無需設計損失函數, 降低了模型的複雜度, 擴大了適用範圍.

本例中的估計函數 p(x;θ)p(x;\theta) 相當於生成對抗網絡的生成模型, 樣本分佈 q(x)q(x) 相當於訓練數據.

3. KL散度

結合上例, 在樣本已知的情況下, q(x)q(x) 是一個已知且確定的分佈. 則 q(x)  log  q(x)dx\int q(x)\;log\;q(x)dx 是一個常數項, 不影響 θ\theta^* 求解的結果.

θ=argmaxθ(q(x)  log  p(x;θ)dxq(x)  log  q(x)dx)  =argmaxθq(x)  (log  p(x;θ)log  q(x))dx  =argmaxθq(x)  log  p(x;θ)q(x)dx \theta^*=\underset{\theta}{argmax}(\int q(x)\;log\;p(x;\theta)dx-\int q(x)\;log\;q(x)dx)\\ \;\\ =\underset{\theta}{argmax}\int q(x)\;(log\;p(x;\theta)-log\;q(x))dx\\ \;\\ =\underset{\theta}{argmax}\int q(x)\;log\;\frac{p(x;\theta)}{q(x)}dx\\
也可以寫成這樣 :
θ=argminθ(q(x)  log  p(x;θ)dx+q(x)  log  q(x)dx)  =argminθq(x)  log  q(x)p(x;θ)dx \theta^*=\underset{\theta}{argmin}(-\int q(x)\;log\;p(x;\theta)dx+\int q(x)\;log\;q(x)dx)\\ \;\\ =\underset{\theta}{argmin}\int q(x)\;log\;\frac{q(x)}{p(x;\theta)}dx\\

KL 散度 ( Kullback–Leibler divergence ) 是一種衡量兩個概率分佈的匹配程度的指標, 兩個分佈差異越大, KL散度越大. 它還有很多名字, 比如: relative entropy, relative information.

其定義如下 :
DKL(qp)=q(x)  log  q(x)p(x)dx D_{KL}(q||p)=\int q(x)\;log\;\frac{q(x)}{p(x)}dx
p(x)q(x)p(x)\equiv q(x) 時取得最小值 DKL(qp)=0D_{KL}(q||p) = 0.

我們可以將上面的公式簡化爲 :
θ=argminθ  DKL(qp(x;θ)) \theta^*=\underset{\theta}{argmin}\;D_{KL}(q||p(x;\theta))

4. JS 散度

KL 散度是非對稱的,即 DKL(qp)D_{KL}(q||p) 不一定等於 DKL(pq)D_{KL}(p||q). 爲了解決這個問題, 需要引入 JS 散度.

JS 散度 ( Jensen–Shannon divergence ) 的定義如下 :
m=12(p+q)  DJS=12DKL(pm)+12DKL(qm) m =\frac{1}{2}(p + q) \\ \;\\ D_{JS}=\frac{1}{2}D_{KL}(p||m) + \frac{1}{2}D_{KL}(q||m)
JS 的值域是對稱的, 有界的, 範圍是 [0,1].

如果 p, q 完全相同, 則 JS = 0, 如果完全不相同, 則 JS = 1.

注意, KL 散度和 JS 散度作爲差異度量的時候, 有一個問題:

如果兩個分配 p, q 離得很遠, 完全沒有重疊的時候, 那麼 KL 散度值是沒有意義的, 而 JS 散度值是一個常數. 這在學習算法中是比較致命的, 這就意味這這一點的梯度爲0, 梯度消失了.

參考上例, 對 JS 散度進行反推:
DJS(qp)=12DKL(qm)+12DKL(pm)  =12q(x)  log  q(x)q(x)+p(x;θ)2dx+12p(x;θ)  log  p(x;θ)p(x;θ)+q(x)2dx  =12q(x)  log  2q(x)q(x)+p(x;θ)dx+12p(x;θ)  log  2p(x;θ)p(x;θ)+q(x)dx D_{JS}(q||p)=\frac{1}{2}D_{KL}(q||m)+\frac{1}{2}D_{KL}(p||m)\\ \;\\ =\frac{1}{2}\int q(x)\;log\;\frac{q(x)}{\frac{q(x)+p(x;\theta)}{2}}dx+ \frac{1}{2}\int p(x;\theta)\;log\;\frac{p(x;\theta)}{\frac{p(x;\theta)+q(x)}{2}}dx\\ \;\\ =\frac{1}{2}\int q(x)\;log\;\frac{2q(x)}{q(x)+p(x;\theta)}dx+ \frac{1}{2}\int p(x;\theta)\;log\;\frac{2p(x;\theta)}{p(x;\theta)+q(x)}dx
由於 :
q(x)  log  2q(x)q(x)+p(x;θ)dx  =q(x)  (log  q(x)q(x)+p(x;θ)+log2)dx  =q(x)  log  q(x)q(x)+p(x;θ)dx+q(x)(log2)dx  =q(x)  log  q(x)q(x)+p(x;θ)dx+log2 \int q(x)\;log\;\frac{2q(x)}{q(x)+p(x;\theta)}dx\\ \;\\ =\int q(x)\;(log\;\frac{q(x)}{q(x)+p(x;\theta)}+log2)dx\\ \;\\ =\int q(x)\;log\;\frac{q(x)}{q(x)+p(x;\theta)}dx+\int q(x)(log2)dx\\ \;\\ =\int q(x)\;log\;\frac{q(x)}{q(x)+p(x;\theta)}dx+log2
同理可得 :
DJS(qp)=12q(x)  log  q(x)q(x)+p(x;θ)dx+12p(x;θ)  log  p(x;θ)p(x;θ)+q(x)dx+log2 D_{JS}(q||p)=\frac{1}{2}\int q(x)\;log\;\frac{q(x)}{q(x)+p(x;\theta)}dx+ \frac{1}{2}\int p(x;\theta)\;log\;\frac{p(x;\theta)}{p(x;\theta)+q(x)}dx+log2
令 :
d(x;θ)=q(x)q(x)+p(x;θ) d(x;\theta)=\frac{q(x)}{q(x)+p(x;\theta)}
則 :
1d(x;θ)=p(x;θ)q(x)+p(x;θ) 1-d(x;\theta)=\frac{p(x;\theta)}{q(x)+p(x;\theta)}
即 :
DJS(qp)=12q(x)  log  d(x;θ)dx+12p(x;θ)  log  (1d(x;θ))dx+log2 D_{JS}(q||p)=\frac{1}{2}\int q(x)\;log\;d(x;\theta)dx+ \frac{1}{2}\int p(x;\theta)\;log\;(1-d(x;\theta))dx+log2
令 :
V(x;θ)=q(x)  log  d(x;θ)dx+p(x;θ)  log  (1d(x;θ))dx V(x;\theta) =\int q(x)\;log\;d(x;\theta)dx+ \int p(x;\theta)\;log\;(1-d(x;\theta))dx
則 :
DJS(qp)=12V(x;θ)+log2 D_{JS}(q||p)=\frac{1}{2}V(x;\theta)+log2
即 :
θ=argminθ  DJS(qp)=argminθ  V(x;θ) \theta^*=\underset{\theta}{argmin}\;D_{JS}(q||p)=\underset{\theta}{argmin}\;V(x;\theta)
此時, θ\theta^* 是令 p(x;θ)p(x;\theta)q(x)q(x) 差異最小的參數. 同樣亦可通過 V(x;θ)V(x;\theta) 求差異最大的參數.

5. JS 散度參數求解的兩步走迭代方法

從上例的討論我們知道, 我們需要求得一個參數 θ\theta^*, 使得
θ=argminθDJS(qp)=argminθV(x;θ) \theta^*=\underset{\theta}{argmin}D_{JS}(q||p)=\underset{\theta}{argmin} V(x;\theta)
然而, 因爲涉及多重嵌套和積分, 使用起來並不方便.

首先, 我們假設 p(x;θ)=pg(x)p(x;\theta) = p_g(x) 爲已知條件, 同時令 D=d(x;θ)D=d(x;\theta), 考慮這個式子:
W(x;θ)=q(x)  log  d(x;θ)dx+p(x;θ)  log  (1d(x;θ))  W(x;D)=q(x)  log  D+pg(x)  log  (1D)  V(x;θ)=V(x;D)=W(x;D)dx W(x;\theta)=q(x)\;log\;d(x;\theta)dx+ p(x;\theta)\;log\;(1-d(x;\theta))\\ \;\\ W(x;D)=q(x)\;log\;D+p_g(x)\;log\;(1-D)\\ \;\\ V(x;\theta)=V(x;D)=\int W(x;D)dx

xx 已經確定的情況下, 我們關注 DD.
W=dWdD=q(x)1Dpg(x)11D  W=dWdD=q(x)1D2pg(x)1(1D)2 W'=\frac{dW}{dD}=q(x)\frac{1}{D}-p_g(x)\frac{1}{1-D}\\ \;\\ W''=\frac{dW'}{dD}=-q(x)\frac{1}{D^2}-p_g(x)\frac{1}{(1-D)^2}

因爲 W<0W'' < 0, 當 W=0W'=0 時, WW 取得極大值 :
W=q(x)1Dpg(x)11D=0  D=q(x)q(x)+pg(x) W'=q(x)\frac{1}{D}-p_g(x)\frac{1}{1-D}=0\\ \;\\ D = \frac{q(x)}{q(x)+p_g(x)}
因爲 :
D<q(x)q(x)+pg(x),    W>0  D>q(x)q(x)+pg(x),    W<0 D < \frac{q(x)}{q(x)+p_g(x)},\;\;W'>0\\ \;\\ D > \frac{q(x)}{q(x)+p_g(x)},\;\;W'<0
這表明, 當 DD 的函數按 W=0W'=0 取值時, WWxx 的每個取樣點均獲得最大值, 積分後的面積獲得最大值, 即 :
D=q(x)q(x)+pg(x)=argmaxDW(x;D)dx=argmaxDV(x;D) D^*=\frac{q(x)}{q(x)+p_g(x)}=\underset{D}{argmax}\int W(x;D)dx=\underset{D}{argmax}V(x;D)
maxD  V(x;D)=q(x)  log  D(x)dx+pg(x)  log  (1D(x))dx \underset{D}{max}\;V(x;D)=\int q(x)\;log\;D^*(x)dx+\int p_g(x)\;log\;(1-D^*(x))dx

在得到 V(x;D)V(x;D) 的最大值表達式後, 我們固定 DD^*, 接着對 p(x;θ)=pg(x)p(x;\theta) = p_g(x) 將這個最大值的按最小方向優化 :
V(x;θ;D)=q(x)  log  D(x)  dx+p(x;θ)  log  (1D(x))dx  θ=argminθ  V(x;θ;D) V(x;\theta;D^*)=\int q(x)\;log\;D^*(x)\;dx+\int p(x;\theta)\;log\;(1-D^*(x))dx\\ \;\\ \theta^*=\underset{\theta}{argmin}\;V(x;\theta^*;D^*)
由此, 通過兩步走的方法, 經過多次先後迭代求解 DD^*θ\theta^*, 我們可以逐漸得到一個趨近於 q(x)q(x)p(x;θ)p(x;\theta^*).

6. 生成對抗網絡

從上面的討論方法可知, 我們可以得到一個和 q(x)q(x) 非常接近的分佈函數 p(x;θ)p(x;\theta). 這個分佈函數的構建是爲了尋找已知樣本數據的內在規律.

然而我們往往並不關心這個分佈函數. 我們希望無中生有的構建一批數據 xx', 使得 p(x;θ)p(x';\theta) 趨近於 q(x)q(x).

我們設計一個輸出 xx' 的生成器 x=G(z;β)x'=G(z;\beta), 從隨機概率分佈中接收 zz 作爲輸入, xx' 的概率分佈爲 pg(x)p_g(x').

第一步, 我們固定 pg(x)p_g(x')DD^*.
V(x,x;D)=q(x)  log  D(x)  dx+pg(x)  log  (1D(x))dx  D=argmaxDV(x;D) V(x,x';D)=\int q(x)\;log\;D(x)\;dx+\int p_g(x')\;log\;(1-D(x'))dx\\ \;\\ D^*=\underset{D}{argmax}V(x;D)

第二步, 我們固定 DD^*pg(x;θ)p_g(x';\theta^*).
V(x,x,D;θ)=q(x)  log  D(x)  dx+pg(x;θ)  log  (1D(x))dx  θ=argminθ  V(x,D;θ) V(x,x',D^*;\theta)=\int q(x)\;log\;D^*(x)\;dx+\int p_g(x';\theta)\;log\;(1-D^*(x'))dx\\ \;\\ \theta^*=\underset{\theta}{argmin}\;V(x,D^*;\theta^*)

然後進行多次循環迭代, 使得 pg(x;θ)p_g(x';\theta^*) 趨近於 q(x)q(x).

讀者可以發現, 這裏求解過程和上例的是一樣, 只是輸入的數據並不一致.

在實際任務中, 我們並不關心 pg(x;θ)p_g(x';\theta), 僅關注生成器 x=G(z;β)x'=G(z;\beta) 的優化.

由此我們將算法改編如下 :

第一步, 我們固定 x=G(z;β)x'=G(z;\beta)DD^*.
V(x,z;D)=q(x)  log  D(x)  dx+q(z)  log  (1D(G(z)))dz  D=argmaxDV(x,z;D) V(x,z;D)=\int q(x)\;log\;D(x)\;dx+\int q(z)\;log\;(1-D(G(z)))dz\\ \;\\ D^*=\underset{D}{argmax}V(x,z;D)

第二步, 我們固定 DD^*G(z;β)G(z;\beta^*).
V(x,z,D;β)=q(x)  log  D(x)  dx+q(z)  log  (1D(G(z;β)))dz  β=argminβ  V(x,z,D;β) V(x,z,D^*;\beta)=\int q(x)\;log\;D^*(x)\;dx+\int q(z)\;log\;(1-D^*(G(z;\beta)))dz\\ \;\\ \beta^*=\underset{\beta}{argmin}\;V(x,z,D^*;\beta)

注意, 本例的兩個算法都沒有給出嚴格的收斂證明.

由於求解形式和上例的 JS 散度的參數求解算法非常一致, 我們可以期待這種算法能夠起到作用.

爲簡單起見, 我們記 :
V(G,D)=q(x)  log  D(x)  dx+q(z)  log  (1D(G(z)))dz  G=argminG  (  maxD  V(G,D)  ) V(G,D)=\int q(x)\;log\;D(x)\;dx+\int q(z)\;log\;(1-D(G(z)))dz\\ \;\\ G^*=\underset{G}{argmin}\;(\;\underset{D}{max}\;V(G,D)\;)
這就是 GAN 生成對抗網絡相關文獻中常見的求解表達式.

在 Ian J. Goodfellow 的論文 Generative Adversarial Networks 中, 作者先給出了 V(G,D)V(G,D) 的表達式, 然後再通過 JS 散度的理論來證明其收斂性. 有興趣的讀者可以參考閱讀.

本文認爲, 如果先介紹 JS 散度, 再進行反推, 可以更容易的理解 GAN 概念, 理解 GAN 爲什麼要用這麼複雜的損失函數.

7. 生成對抗網絡的工程實踐

在工程實踐中, 我們遇到的一般是離散的數據. 我們可以使用隨機採樣的方法來逼近期望值.

首先我們從前置的隨機分佈 pz(z)p_z(z) 中取出 mm 個隨機數 z=(z1,z2,z3, ,zm)z=(z_1,z_2,z_3,\cdots,z_m), 其次我們在從真實數據分佈 p(x)p(x) 中取出 mm 個真實樣本 p=(x1,x2,x3, ,xm)p=(x_1,x_2,x_3,\cdots,x_m).

由於我們的數據是隨機選取的, 概率越大就越有機會被選中. 抽取的樣本就隱含了自身的期望. 因此我們可以使用平均數代替上式中的期望, 公式改寫如下.
V(G,D)=q(x)  log  D(x)  dx+q(z)  log  (1D(G(z)))  dz  =1mi=1mlog  D(xi)+1mi=1mlog  (1D(G(zi))) V(G,D)=\int q(x)\;log\;D(x)\;dx+\int q(z)\;log\;(1-D(G(z)))\;dz\\ \;\\ =\frac{1}{m}\sum_{i=1}^{m}log\;D(x_i) + \frac{1}{m}\sum_{i=1}^{m}log\;(1-D(G(z_i)))

我們可以直接用上式訓練鑑別器 D(x)D(x)​.

在訓練生成器時, 因爲前半部分和 zz 無關, 我們可以只使用後半部分.

全文完.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章