tensorflow(九)生成式對抗網絡(GAN)上篇----簡介與算法原理

一、簡介

生成式對抗網絡(GAN, Generative Adversarial Networks )是一種深度學習模型,是近年來複雜分佈上無監督學習最具前景的方法之一。
模型通過框架中(至少)兩個模塊:生成模型(Generative Model)和判別模型(Discriminative Model)的互相博弈學習產生相當好的輸出。原始 GAN 理論中,並不要求 G 和 D 都是神經網絡,只需要是能擬合相應生成和判別的函數即可。但實用中一般均使用深度神經網絡作爲 G 和 D 。一個優秀的GAN應用需要有良好的訓練方法,否則可能由於神經網絡模型的自由性而導致輸出不理想。
下面是使用GAN生成的可愛的人臉頭像。(圖片來源於:http://baijiahao.baidu.com/s?id=1568663805038898&wfr=spider&for=pc
這裏寫圖片描述
可以看出生成的圖片相似度很高,很有趣。

二、GAN原理

GAN 主要包括了兩個部分,即生成器 generator 與判別器 discriminator。生成器主要用來學習真實圖像分佈從而讓自身生成的圖像更加真實,以騙過判別器。判別器則需要對接收的圖片進行真假判別。在整個過程中,生成器努力地讓生成的圖像更加真實,而判別器則努力地去識別出圖像的真假,這個過程相當於一個二人博弈,隨着時間的推移,生成器和判別器在不斷地進行對抗,最終兩個網絡達到了一個動態均衡:生成器生成的圖像接近於真實圖像分佈,而判別器識別不出真假圖像,對於給定圖像的預測爲真的概率基本接近 0.5(相當於隨機猜測類別)。

三、GAN的應用

1、圖像生成
目前GAN最常使用的地方就是圖像生成,如超分辨率任務,語義分割等等。
2、數據增強
用GAN生成的圖像來做數據增強,主要解決對於小數據集,數據量不足的情況

四、GAN詳解

公式和圖片來源於原論文,論文地址爲:https://arxiv.org/pdf/1406.2661.pdf

1、GAN最終需要優化的目標函數如下:
這裏寫圖片描述
公式中,x表示真實圖片,z表示輸入G網絡的噪聲,而G(z)表示G網絡生成的圖片。
D(x)表示D網絡判斷真實圖片是否真實的概率(因爲x就是真實的,所以對於D來說,這個值越接近1越好)。而D(G(z))是D網絡判斷G生成的圖片的是否真實的概率。
G的目的:D(G(z))是D網絡判斷G生成的圖片是否真實的概率,G應該希望自己生成的圖片“越接近真實越好”。也就是說,G希望D(G(z))儘可能得大,這時V(D, G)會變小。因此我們看到式子的最前面的記號是min_G。
D的目的:D的能力越強,D(x)應該越大,D(G(x))應該越小。這時V(D,G)會變大。因此式子對於D來說是求最大,記爲max_D。

2、算法流程圖如下:
這裏寫圖片描述
3、算法最終實現的結果如下圖所示:
這裏寫圖片描述
圖中虛線代表真實數據的分佈,這裏爲高斯分佈,紅色的線代表隨機初始化的噪聲。GAN的目標便是通過不斷的訓練D和G網絡,最終使紅色曲線逐漸擬合虛線,讓D網絡傻傻分不清楚兩條線。即由圖一到圖四的過程。

五、關於GAN的一些論文

這裏寫圖片描述
這裏寫圖片描述

六、tensorflow實現過程

    問題求解大概步驟:
    1、定義D、G網絡的結構,定義輸入的shape,把每層的shape和初始化方式定義好,定義D和G兩個網絡的loss
    2、train訓練 傳入兩組數據,真實的數據x和隨機初始化的數據z,並進行迭代優化求解
    3、可視化打印輸出

由於篇幅原因,詳細代碼和註釋見下一篇博文。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章