簡介
上節在講原文GAN的時候,提到我們實際是在用Discriminator來衡量兩個數據的分佈之間的JS divergence,那能不能是其他類型的divergence來衡量真實數據和生成數據之間的差距?又如何進行衡量?(雖然在實作上用不同divergence結果沒有很大差別)
公式輸入請參考:在線Latex公式
f-divergence
任意的divergence都可以用來衡量真實數據和生成數據之間的差距,用f-divergence進行衡量的算法就叫fGAN。先來看看f-divergence的概念:
P P P and Q Q Q are two distributions. p ( x ) p(x) p ( x ) and q ( x ) q(x) q ( x ) are the probability of sampling x x x .
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x , f is convex, f ( 1 ) = 0 D_f(P||Q)=\int_xq(x)f\left(\cfrac{p(x)}{q(x)}\right)dx,\text{f is convex, }f(1)=0 D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( q ( x ) p ( x ) ) d x , f is convex, f ( 1 ) = 0
如果兩個分佈相同,那麼f-divergence的值應該相等,我們來驗證一下:
當p ( x ) = q ( x ) for all x p(x)=q(x) \text{ for all } x p ( x ) = q ( x ) for all x
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x = 0 D_f(P||Q)=\int_xq(x)f\left(\cfrac{p(x)}{q(x)}\right)dx=0 D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( q ( x ) p ( x ) ) d x = 0
因爲:p ( x ) q ( x ) = 1 , f ( 1 ) = 0 \cfrac{p(x)}{q(x)}=1,f(1)=0 q ( x ) p ( x ) = 1 , f ( 1 ) = 0 ,所以divergence爲0,是最小的f-divergence。證明如下:
Because f is convex,因此有(右邊是左邊的lower bound):
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x ≥ f ( ∫ x q ( x ) p ( x ) q ( x ) d x ) D_f(P||Q)=\int_xq(x)f\left(\cfrac{p(x)}{q(x)}\right)dx\ge f\left(\int_xq(x)\cfrac{p(x)}{q(x)}dx\right) D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( q ( x ) p ( x ) ) d x ≥ f ( ∫ x q ( x ) q ( x ) p ( x ) d x )
f ( ∫ x q ( x ) p ( x ) q ( x ) d x ) = f ( ∫ x p ( x ) d x ) = f ( 1 ) = 0 f\left(\int_xq(x)\cfrac{p(x)}{q(x)}dx\right)=f\left(\int_xp(x)dx\right)=f(1)=0 f ( ∫ x q ( x ) q ( x ) p ( x ) d x ) = f ( ∫ x p ( x ) d x ) = f ( 1 ) = 0
如果f f f 是不同的函數,就得到不同的divergence,例如:f ( x ) = x l o g x f(x)=xlogx f ( x ) = x l o g x
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) p ( x ) q ( x ) l o g ( p ( x ) q ( x ) ) d x = ∫ x p ( x ) l o g ( p ( x ) q ( x ) ) d x D_f(P||Q)=\int_xq(x)\cfrac{p(x)}{q(x)}log\left(\cfrac{p(x)}{q(x)}\right)dx\\
=\int_xp(x)log\left(\cfrac{p(x)}{q(x)}\right)dx D f ( P ∣ ∣ Q ) = ∫ x q ( x ) q ( x ) p ( x ) l o g ( q ( x ) p ( x ) ) d x = ∫ x p ( x ) l o g ( q ( x ) p ( x ) ) d x
這個是KL divergence。
例如:f ( x ) = − l o g x f(x)=-logx f ( x ) = − l o g x
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) ( − l o g ( p ( x ) q ( x ) ) ) d x = ∫ x q ( x ) l o g ( q ( x ) p ( x ) ) d x D_f(P||Q)=\int_xq(x)\left(-log\left(\cfrac{p(x)}{q(x)}\right)\right)dx\\
=\int_xq(x)log\left(\cfrac{q(x)}{p(x)}\right)dx D f ( P ∣ ∣ Q ) = ∫ x q ( x ) ( − l o g ( q ( x ) p ( x ) ) ) d x = ∫ x q ( x ) l o g ( p ( x ) q ( x ) ) d x
這個是Reverse KL divergence。
例如:f ( x ) = ( x − 1 ) 2 f(x)=(x-1)^2 f ( x ) = ( x − 1 ) 2
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) ( p ( x ) q ( x ) − 1 ) 2 d x = ∫ x ( p ( x ) − q ( x ) ) 2 q ( x ) d x D_f(P||Q)=\int_xq(x)\left(\cfrac{p(x)}{q(x)}-1\right)^2dx\\
=\int_x\cfrac{\left(p(x)-q(x)\right)^2}{q(x)}dx D f ( P ∣ ∣ Q ) = ∫ x q ( x ) ( q ( x ) p ( x ) − 1 ) 2 d x = ∫ x q ( x ) ( p ( x ) − q ( x ) ) 2 d x
這個是Chi Square divergence。
Fenchel Conjugate
每一個f f f 凸函數都有一個Conjugate函數記爲f ∗ f^* f ∗ ,公式如下:
f ∗ ( t ) = max x ∈ d o m ( f ) { x t − f ( x ) } f^*(t)=\underset{x\in dom(f)}{\text{max}}\{xt-f(x)\} f ∗ ( t ) = x ∈ d o m ( f ) max { x t − f ( x ) }
窮舉所有的t , x t,x t , x ,然後找到能使得x t − f ( x ) xt-f(x) x t − f ( x ) 最大的t , x t,x t , x 。
比較笨的窮舉法如下:
另外一種方法:函數x t − f ( x ) xt-f(x) x t − f ( x ) 是直線,我們帶不同的x x x 得到不同的直線,例如下面有三條直線:
然後找不同的t對應的最大值。(就是所有直線的upper bound)
上面的紅線無論你如何畫,最後都是convex的。
看例子,假設:f ( x ) = x l o g x f(x)=xlogx f ( x ) = x l o g x ,把x = 0.1 , x = 1 , x = 10 x=0.1,x=1,x=10 x = 0 . 1 , x = 1 , x = 1 0 ,圖片如下:
紅線最後接近
f ∗ ( t ) = e x p ( t − 1 ) f^*(t)=exp(t-1) f ∗ ( t ) = e x p ( t − 1 )
下面是數學證明,假設f ( x ) = x l o g x f(x)=xlogx f ( x ) = x l o g x
則:
f ∗ ( t ) = max x ∈ d o m ( f ) { x t − f ( x ) } = max x ∈ d o m ( f ) { x t − x l o g x } (1) f^*(t)=\underset{x\in dom(f)}{\text{max}}\{xt-f(x)\}=\underset{x\in dom(f)}{\text{max}}\{xt-xlogx\}\tag1 f ∗ ( t ) = x ∈ d o m ( f ) max { x t − f ( x ) } = x ∈ d o m ( f ) max { x t − x l o g x } ( 1 )
令上式中x t − x l o g x = g ( x ) , Given t , find x maximizing g ( x ) xt-xlogx=g(x)\text{, Given }t\text{, find }x\text{ maximizing }g(x) x t − x l o g x = g ( x ) , Given t , find x maximizing g ( x )
求極值,就是對x x x 求導數等於0:
g ′ ( x ) = t − l o g x − 1 = 0 x = e x p ( t − 1 ) g'(x)=t-logx-1=0\\
x=exp(t-1) g ′ ( x ) = t − l o g x − 1 = 0 x = e x p ( t − 1 )
把上面內容代入公式(1):
f ∗ ( t ) = x t − x l o g x = e x p ( t − 1 ) × t − e x p ( t − 1 ) × ( t − 1 ) = e x p ( t − 1 ) f^*(t)=xt-xlogx=exp(t-1)\times t-exp(t-1)\times(t-1)=exp(t-1) f ∗ ( t ) = x t − x l o g x = e x p ( t − 1 ) × t − e x p ( t − 1 ) × ( t − 1 ) = e x p ( t − 1 )
一般化後:
( f ∗ ) ∗ = f (f^*)^*=f ( f ∗ ) ∗ = f
講這麼多,下面看下上面兩節內容和GAN的關係
Connection with GAN
通過上面的推導我們知道f ∗ ( t ) f^*(t) f ∗ ( t ) 和f ( x ) f(x) f ( x ) 互爲Conjugate,寫爲:
f ∗ ( t ) = max x ∈ d o m ( f ) { x t − f ( x ) } ← → f ( x ) = max t ∈ d o m ( f ∗ ) { x t − f ∗ ( t ) } f^*(t)=\underset{x\in dom(f)}{\text{max}}\{xt-f(x)\}\leftarrow\rightarrow f(x)=\underset{t\in dom(f^*)}{\text{max}}\{xt-f^*(t)\} f ∗ ( t ) = x ∈ d o m ( f ) max { x t − f ( x ) } ← → f ( x ) = t ∈ d o m ( f ∗ ) max { x t − f ∗ ( t ) }
這兩個相互Conjugate的convex的函數有什麼特別?我們繼續看。
D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( p ( x ) q ( x ) ) d x = ∫ x q ( x ) ( max t ∈ d o m ( f ∗ ) { p ( x ) q ( x ) t − f ∗ ( t ) } ) d x D_f(P||Q)=\int_xq(x)f\left(\cfrac{p(x)}{q(x)}\right)dx\\
=\int_xq(x)\left(\underset{t\in dom(f^*)}{\text{max}}\left\{\cfrac{p(x)}{q(x)}t-f^*(t)\right\}\right)dx D f ( P ∣ ∣ Q ) = ∫ x q ( x ) f ( q ( x ) p ( x ) ) d x = ∫ x q ( x ) ( t ∈ d o m ( f ∗ ) max { q ( x ) p ( x ) t − f ∗ ( t ) } ) d x
我們用函數D ( x ) D(x) D ( x ) 代替t t t ,使得輸入x,輸出爲t,使得上面{}中的值最大,替換後,就找到了D f ( P ∣ ∣ Q ) D_f(P||Q) D f ( P ∣ ∣ Q ) 的lower bound
D f ( P ∣ ∣ Q ) ≥ ∫ x q ( x ) ( p ( x ) q ( x ) D ( x ) − f ∗ ( D ( x ) ) ) d x = ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||Q)\ge\int_xq(x)\left(\cfrac{p(x)}{q(x)}D(x)-f^*(D(x))\right)dx\\
=\int_xp(x)D(x)dx-\int_xq(x)f^*(D(x))dx D f ( P ∣ ∣ Q ) ≥ ∫ x q ( x ) ( q ( x ) p ( x ) D ( x ) − f ∗ ( D ( x ) ) ) d x = ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x
當我們找的函數D ( x ) D(x) D ( x ) 是最好的,那麼就可以逼近D f ( P ∣ ∣ Q ) D_f(P||Q) D f ( P ∣ ∣ Q )
D f ( P ∣ ∣ Q ) ≈ max D ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x D_f(P||Q)\approx\underset{D}{\text{max}}\int_xp(x)D(x)dx-\int_xq(x)f^*(D(x))dx D f ( P ∣ ∣ Q ) ≈ D max ∫ x p ( x ) D ( x ) d x − ∫ x q ( x ) f ∗ ( D ( x ) ) d x
積分可以寫成期望值:
= max D { E x ∼ P [ D ( x ) ] − E x ∼ Q [ f ∗ ( D ( x ) ) ] } =\underset{D}{\text{max}}\{E_{x\sim P}[D(x)]-E_{x\sim Q}[f^*(D(x))]\} = D max { E x ∼ P [ D ( x ) ] − E x ∼ Q [ f ∗ ( D ( x ) ) ] }
接下來我們把P = P d a t a , Q = P G P=P_{data},Q=P_G P = P d a t a , Q = P G ,則有:
D f ( P d a t a ∣ ∣ P G ) = max D { E x ∼ P d a t a [ D ( x ) ] − E x ∼ P G [ f ∗ ( D ( x ) ) ] } D_f(P_{data}||P_G)=\underset{D}{\text{max}}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[f^*(D(x))]\} D f ( P d a t a ∣ ∣ P G ) = D max { E x ∼ P d a t a [ D ( x ) ] − E x ∼ P G [ f ∗ ( D ( x ) ) ] }
把上面的式子可以帶回求Generator的公式:
G ∗ = a r g min G D f ( P d a t a ∣ ∣ P G ) = a r g min G max D { E x ∼ P d a t a [ D ( x ) ] − E x ∼ P G [ f ∗ ( D ( x ) ) ] } = a r g min G max D V ( G , D ) \begin{aligned}G^* &=arg\underset{G}{\text{min}}D_f(P_{data}||P_G)\\
&=arg\underset{G}{\text{min}}\underset{D}{\text{max}}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[f^*(D(x))]\}\\
&=arg\underset{G}{\text{min}}\underset{D}{\text{max}}V(G,D)\end{aligned} G ∗ = a r g G min D f ( P d a t a ∣ ∣ P G ) = a r g G min D max { E x ∼ P d a t a [ D ( x ) ] − E x ∼ P G [ f ∗ ( D ( x ) ) ] } = a r g G min D max V ( G , D )
也就是說我們可以優化不同的divergence(https://arxiv.org/pdf/1606.00709.pdf)
可以用來解決Mode Collapse
Mode Collapse
當我們的GAN模型Training with too many iterations……
有些人臉就會比較像,除了一些顏色不太一樣
從數學上說就是我們的生成對象的分佈越來越小了
Mode Dropping
就是真實數據有兩簇或者多簇,但是生成數據只能生成其中一簇:
例如下面的人臉,一個循環只有白種人,一個循環只有黃種人,一個循環中只有黑人。
問題分析
傳統的做法類似MLE實際上是最小化KL divergence,可以看到生成數據的分佈是在真實數據中間,這也是爲什麼真實數據這麼模糊
如果換成Reverse KL Divergence:
可以看到解決了模糊的問題,但是會出現Mode Dropping的問題。
因此選擇不同的divergence對於GAN很重要。
解決Mode Collapse
Unsemble:訓練多個generator,然後隨機調一個generator來生成圖片,這樣結果就會比較diverse。
Train a set of generators: { G 1 , G 2 , … , G N } \{G_1,G_2,…,G_N\} { G 1 , G 2 , … , G N }
To generate an image Random pick a generator G i G_i G i . Use G i G_i G i generate the image.