GAN and CGAN on MNIST dataset

GAN and CGAN

There are plenty of tutorials on GAN and CGAN online; if you are interested, go find a few and read them. The most important thing is to be clear about what a GAN is actually for. The author explains the motivation behind GANs at length both in the original paper [1] and in the NIPS 2016 tutorial. Simply put, a GAN is used to fit the distribution of the training samples.
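
For reference, this is the minimax objective from the original paper [1], with generator G and discriminator D:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

D is trained to tell real samples from generated ones, G is trained to fool D, and at the optimum the generator's distribution matches the data distribution.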

Reading the paper alone never felt like a thorough enough understanding, so I found some simple programs online and ran them. Below I go straight to the programs and the experimental results.


GAN

The structure of a GAN is shown in the figure below.
[Figure: GAN architecture]

The experiments use the MNIST handwritten digit images. The code follows an example from GitHub and builds a neural network with TensorFlow 1.2. Strictly speaking there are two networks: the Discriminator and the Generator are each a neural network. The two networks have similar structures, each with two fully connected layers, the last of which feeds a tanh output (generator) or a sigmoid output (discriminator).

Code

The code is adapted from someone else's work; the original author wrote it in an IPython notebook and plotted the figures himself. I use TensorBoard for visualization instead, so I modified the code accordingly. Only the main parts of the program are shown below; please see the complete code for the rest (incidentally, that is also where you can find the original figures without watermarks).

# reference https://github.com/NELSONZHAO/zhihu/tree/master/mnist_gan
import tensorflow as tf


def get_generator(noise_img, n_units, out_dim, reuse = False, alpha = 0.01):
    """
    generator

    noise_img: input of generator
    n_units: # hidden units
    out_dim: # output units (784 for flattened 28x28 MNIST images)
    alpha: slope parameter of the leaky ReLU
    """
    with tf.variable_scope("generator", reuse = reuse):
        # hidden layer
        hidden = tf.layers.dense(noise_img, n_units)
        # leaky ReLU: max(alpha * x, x)
        relu = tf.maximum(alpha * hidden, hidden)
        # dropout (note: tf.layers.dropout is a no-op unless training=True is passed)
        drop = tf.layers.dropout(relu, rate = 0.5)

        # logits & outputs; tanh squashes the outputs to [-1, 1]
        logits = tf.layers.dense(drop, out_dim)
        outputs = tf.tanh(logits)

        return logits, outputs


def get_discriminator(img, n_units, reuse = False, alpha = 0.01):
    """
    discriminator

    img: input image (real or generated)
    n_units: # hidden units
    alpha: slope parameter of the leaky ReLU
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        # hidden layer
        hidden = tf.layers.dense(img, n_units)
        # leaky ReLU
        relu = tf.maximum(alpha * hidden, hidden)
        # dropout (again a no-op unless training=True)
        drop = tf.layers.dropout(relu, rate = 0.5)

        # logits & outputs; sigmoid gives the probability that the input is real
        logits = tf.layers.dense(drop, 1)
        outputs = tf.sigmoid(logits)

        return logits, outputs
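
The helper get_inputs used below is not shown in this excerpt; here is a minimal sketch of what it plausibly looks like, reconstructed from how it is called (the details are my assumption, not the original code):

def get_inputs(real_img_size, noise_img_size):
    # placeholders for flattened real images and for the generator's noise input
    real_img = tf.placeholder(tf.float32, [None, real_img_size], name = "real_img")
    noise_img = tf.placeholder(tf.float32, [None, noise_img_size], name = "noise_img")
    return real_img, noise_img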

...
with tf.Graph().as_default():

    real_img, noise_img = get_inputs(real_img_size, noise_img_size)

    # generator
    g_logits, g_outputs = get_generator(noise_img, g_units, real_img_size)

    # log 10 generated samples as image summaries for TensorBoard
    sample_images = tf.reshape(g_outputs, [-1, 28, 28, 1])
    tf.summary.image("sample_images", sample_images, 10)

    # discriminator: applied to the real images, and again (sharing weights
    # via reuse=True) to the generator's outputs
    d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
    d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse = True)

    # discriminator loss; `smooth` is a label-smoothing factor defined in the
    # elided setup code
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_real,
        labels = tf.ones_like(d_logits_real)) * (1 - smooth))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_fake,
        labels = tf.zeros_like(d_logits_fake)))
    # total discriminator loss
    d_loss = tf.add(d_loss_real, d_loss_fake)

    # generator loss: reward the generator when D classifies the fakes as real
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_fake,
        labels = tf.ones_like(d_logits_fake)) * (1 - smooth) )

    ...
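
The elided part builds separate training ops for the two networks by splitting the trainable variables by scope name. A minimal sketch, assuming Adam and a learning rate I chose for illustration (the names d_train_opt and g_train_opt match the CGAN training loop further down):

# sketch of the elided optimizer setup; the learning rate is an assumed value
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

learning_rate = 0.001   # assumption, not from the original post
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list = d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list = g_vars)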

Experimental Results

I use TensorFlow's TensorBoard for visualization; through TensorBoard you can see the structure of the resulting graph. I also save the sample images the Generator produces during training, as well as the losses of the Generator and the Discriminator.
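
The scalar summaries and the writer are not shown above; a minimal sketch of the wiring using the standard tf.summary API (a reconstruction, not the author's exact code):

tf.summary.scalar("d_loss", d_loss)
tf.summary.scalar("g_loss", g_loss)
summary = tf.summary.merge_all()

# inside the session:
# writer = tf.summary.FileWriter("./logs", sess.graph)
# summary_str = sess.run(summary, feed_dict = ...)
# writer.add_summary(summary_str, e)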

Graph

[Figure: TensorBoard graph of the GAN]

Loss

[Figures: Discriminator and Generator loss curves]

The loss values of the last 30 epochs are shown below:

('Epoch 271/300', 'Discriminator loss: 0.8802(Real: 0.5291 + Fake: 0.3511)', 'Generator loss: 2.0067')
('Epoch 272/300', 'Discriminator loss: 0.8302(Real: 0.4576 + Fake: 0.3727)', 'Generator loss: 1.8738')
('Epoch 273/300', 'Discriminator loss: 0.7560(Real: 0.3527 + Fake: 0.4033)', 'Generator loss: 1.8327')
('Epoch 274/300', 'Discriminator loss: 0.7337(Real: 0.3728 + Fake: 0.3609)', 'Generator loss: 1.9727')
('Epoch 275/300', 'Discriminator loss: 0.7968(Real: 0.4157 + Fake: 0.3811)', 'Generator loss: 1.9415')
('Epoch 276/300', 'Discriminator loss: 0.9710(Real: 0.2390 + Fake: 0.7320)', 'Generator loss: 1.2463')
('Epoch 277/300', 'Discriminator loss: 0.8384(Real: 0.3923 + Fake: 0.4460)', 'Generator loss: 1.7710')
('Epoch 278/300', 'Discriminator loss: 0.8676(Real: 0.5134 + Fake: 0.3542)', 'Generator loss: 2.0165')
('Epoch 279/300', 'Discriminator loss: 0.7826(Real: 0.3927 + Fake: 0.3899)', 'Generator loss: 1.8694')
('Epoch 280/300', 'Discriminator loss: 0.8702(Real: 0.5518 + Fake: 0.3184)', 'Generator loss: 2.0910')
('Epoch 281/300', 'Discriminator loss: 0.9124(Real: 0.4497 + Fake: 0.4627)', 'Generator loss: 1.6900')
('Epoch 282/300', 'Discriminator loss: 0.9092(Real: 0.4113 + Fake: 0.4979)', 'Generator loss: 1.5553')
('Epoch 283/300', 'Discriminator loss: 0.8640(Real: 0.3815 + Fake: 0.4825)', 'Generator loss: 1.6243')
('Epoch 284/300', 'Discriminator loss: 0.8872(Real: 0.4649 + Fake: 0.4223)', 'Generator loss: 1.7229')
('Epoch 285/300', 'Discriminator loss: 0.8070(Real: 0.4819 + Fake: 0.3250)', 'Generator loss: 2.1052')
('Epoch 286/300', 'Discriminator loss: 0.8917(Real: 0.3816 + Fake: 0.5100)', 'Generator loss: 1.6345')
('Epoch 287/300', 'Discriminator loss: 0.8642(Real: 0.3220 + Fake: 0.5422)', 'Generator loss: 1.5187')
('Epoch 288/300', 'Discriminator loss: 0.8012(Real: 0.3443 + Fake: 0.4569)', 'Generator loss: 1.7401')
('Epoch 289/300', 'Discriminator loss: 0.7584(Real: 0.4139 + Fake: 0.3445)', 'Generator loss: 2.0334')
('Epoch 290/300', 'Discriminator loss: 0.9034(Real: 0.3890 + Fake: 0.5144)', 'Generator loss: 1.6497')
('Epoch 291/300', 'Discriminator loss: 0.8898(Real: 0.5868 + Fake: 0.3030)', 'Generator loss: 2.1135')
('Epoch 292/300', 'Discriminator loss: 0.9474(Real: 0.3832 + Fake: 0.5642)', 'Generator loss: 1.5061')
('Epoch 293/300', 'Discriminator loss: 0.7980(Real: 0.3361 + Fake: 0.4619)', 'Generator loss: 1.7169')
('Epoch 294/300', 'Discriminator loss: 0.9076(Real: 0.5326 + Fake: 0.3750)', 'Generator loss: 1.9832')
('Epoch 295/300', 'Discriminator loss: 0.8456(Real: 0.4473 + Fake: 0.3983)', 'Generator loss: 1.9126')
('Epoch 296/300', 'Discriminator loss: 0.8395(Real: 0.4567 + Fake: 0.3828)', 'Generator loss: 1.8762')
('Epoch 297/300', 'Discriminator loss: 0.8388(Real: 0.5081 + Fake: 0.3306)', 'Generator loss: 2.0346')
('Epoch 298/300', 'Discriminator loss: 0.8307(Real: 0.4592 + Fake: 0.3715)', 'Generator loss: 1.9360')
('Epoch 299/300', 'Discriminator loss: 0.8607(Real: 0.4490 + Fake: 0.4117)', 'Generator loss: 1.8217')
('Epoch 300/300', 'Discriminator loss: 0.8606(Real: 0.6024 + Fake: 0.2582)', 'Generator loss: 2.3584')

Images

To get a better sense of what the Generator produces, we can also save the images generated during training via TensorBoard and display them. As the figures below show, the images generated by the 299th epoch are close to real ones, but a few still contain some noise, and some are not recognizable as digits at all. On the whole, though, the results are already quite good.

195th Epoch
[Figure: generated samples at the 195th epoch]

299th Epoch
[Figure: generated samples at the 299th epoch]


CGAN

The CGAN network structure is shown in the figure below. The idea in the paper [3] is actually very simple: add the label information to the input, for both the Discriminator and the Generator. As the code shows, the change is tiny; all that is needed is to concatenate the label information with the image pixels (or with the noise, for the generator) and feed that to the network.

[Figure: CGAN architecture]
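
For reference, the conditional objective from the CGAN paper [3] simply conditions both players on the label y:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y) \mid y))]$$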

Code

The code is essentially the same as the GAN code, so only the changed parts are listed here; if you are interested, please see the complete code.


# reference https://github.com/NELSONZHAO/zhihu/tree/master/mnist_gan

def get_inputs(real_img_size, noise_img_size):
    """
    read image tensor and noise image tensor, as well as image digit
    """
    ...
    # k is presumably the number of digit classes (10); labels are fed one-hot
    real_img_digit = tf.placeholder(tf.float32, shape = [None, k])

    ...
    return real_img, noise_img, real_img_digit

def get_generator(digit, noise_img, n_units, out_dim, reuse = False, alpha = 0.01):
    """
    generator

    digit: one-hot label to condition on
    noise_img: input of generator
    n_units: # hidden units
    out_dim: # output units
    alpha: slope parameter of the leaky ReLU
    """
    with tf.variable_scope("generator", reuse = reuse):
        # condition the generator by concatenating the label with the noise
        concatenated_img_digit = tf.concat([digit, noise_img], 1)
        # hidden layer
        hidden = tf.layers.dense(concatenated_img_digit, n_units)

        ...

def get_discriminator(digit, img, n_units, reuse = False, alpha = 0.01):
    """
    discriminator

    digit: one-hot label to condition on
    img: input image (real or generated)
    n_units: # hidden units
    alpha: slope parameter of the leaky ReLU
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        # condition the discriminator by concatenating the label with the pixels
        concatenated_img_digit = tf.concat([digit, img], 1)
        # hidden layer
        hidden = tf.layers.dense(concatenated_img_digit, n_units)
        ...


with tf.Graph().as_default():

    real_img, noise_img, real_img_digit = get_inputs(real_img_size, noise_img_size)

    # generator
    g_logits, g_outputs = get_generator(real_img_digit, noise_img, g_units, real_img_size)

    # discriminator
    d_logits_real, d_outputs_real = get_discriminator(real_img_digit, real_img, d_units)
    d_logits_fake, d_outputs_fake = get_discriminator(real_img_digit, g_outputs, d_units, reuse = True)

    ...

    for e in xrange(epochs):
        for i in xrange(mnist.train.num_examples//(batch_size * k)):
            for j in xrange(k):
                ...
                digits = batch[1]   # one-hot labels from the MNIST batch
                ...

                # Run both optimizers, feeding the labels alongside images and noise
                sess.run([d_train_opt, g_train_opt],
                    feed_dict = {real_img: images, noise_img: noises, real_img_digit: digits})

        # train loss
        digits = mnist.train.labels

        summary_str, train_loss_d_real, train_loss_d_fake, train_loss_g = \
            sess.run([summary, d_loss_real, d_loss_fake, g_loss],
            feed_dict = {real_img: images, noise_img: noises, real_img_digit: digits})

        ...
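
The payoff of the conditioning is that after training we can ask the Generator for a specific digit. A minimal sketch of conditional sampling, assuming k = 10 classes and 784-dimensional flattened images (the batch size of 25 is arbitrary):

import numpy as np

digit_to_draw = 7
labels = np.zeros((25, 10), dtype = np.float32)
labels[:, digit_to_draw] = 1.0    # one-hot condition for the chosen digit
noises = np.random.uniform(-1, 1, size = (25, noise_img_size))

samples = sess.run(g_outputs,
    feed_dict = {noise_img: noises, real_img_digit: labels})
# samples has shape (25, 784); reshape each row to 28x28 for display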



Experimental Results

TensorFlow Graph

The graph contains two Discriminators because during training we feed the real samples and the generated fake samples to the Discriminator separately. The two D's share the same parameters, so in reality there is only one D; TensorFlow just draws two copies to reflect how the graph was built, which happens to be convenient for visualization. I recommend running the program yourself and then looking at it in TensorBoard.

[Figure: TensorBoard graph of the CGAN]

Loss

Toward the end of training, the Discriminator's loss settles at around 1.0 and the Generator's at around 1.3. If you run the program yourself, you can also see that the real and fake parts of the Discriminator loss each stabilize around 0.5, which indicates that D can no longer reliably tell real images from generated ones and is close to guessing.
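
For calibration: ignoring the label-smoothing factor, a discriminator that output exactly 0.5 for every image would incur a per-part sigmoid cross-entropy of -ln(0.5) ≈ 0.693, so per-part losses around 0.5 mean D still retains a slight edge over pure guessing, but not much of one.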

[Figures: Discriminator and Generator loss curves]

Again, the loss values of the last 30 epochs:

('Epoch 271/300', 'Discriminator loss: 1.1468(Real: 0.6671 + Fake: 0.4798)', 'Generator loss: 1.3128')
('Epoch 272/300', 'Discriminator loss: 1.1129(Real: 0.5637 + Fake: 0.5491)', 'Generator loss: 1.1645')
('Epoch 273/300', 'Discriminator loss: 1.0506(Real: 0.4907 + Fake: 0.5599)', 'Generator loss: 1.1404')
('Epoch 274/300', 'Discriminator loss: 0.9998(Real: 0.5912 + Fake: 0.4086)', 'Generator loss: 1.4089')
('Epoch 275/300', 'Discriminator loss: 1.0099(Real: 0.5183 + Fake: 0.4916)', 'Generator loss: 1.2781')
('Epoch 276/300', 'Discriminator loss: 1.0101(Real: 0.5079 + Fake: 0.5022)', 'Generator loss: 1.2244')
('Epoch 277/300', 'Discriminator loss: 1.1071(Real: 0.5453 + Fake: 0.5618)', 'Generator loss: 1.1359')
('Epoch 278/300', 'Discriminator loss: 1.1023(Real: 0.5469 + Fake: 0.5554)', 'Generator loss: 1.1729')
('Epoch 279/300', 'Discriminator loss: 1.0523(Real: 0.5601 + Fake: 0.4922)', 'Generator loss: 1.2823')
('Epoch 280/300', 'Discriminator loss: 1.0357(Real: 0.5382 + Fake: 0.4976)', 'Generator loss: 1.2529')
('Epoch 281/300', 'Discriminator loss: 1.0170(Real: 0.4792 + Fake: 0.5378)', 'Generator loss: 1.1726')
('Epoch 282/300', 'Discriminator loss: 1.0637(Real: 0.6081 + Fake: 0.4556)', 'Generator loss: 1.3511')
('Epoch 283/300', 'Discriminator loss: 0.9885(Real: 0.3953 + Fake: 0.5932)', 'Generator loss: 1.0948')
('Epoch 284/300', 'Discriminator loss: 1.0381(Real: 0.4418 + Fake: 0.5963)', 'Generator loss: 1.1371')
('Epoch 285/300', 'Discriminator loss: 1.0807(Real: 0.4333 + Fake: 0.6475)', 'Generator loss: 1.0169')
('Epoch 286/300', 'Discriminator loss: 1.0113(Real: 0.5422 + Fake: 0.4691)', 'Generator loss: 1.3784')
('Epoch 287/300', 'Discriminator loss: 1.1880(Real: 0.7804 + Fake: 0.4076)', 'Generator loss: 1.4500')
('Epoch 288/300', 'Discriminator loss: 1.0233(Real: 0.5578 + Fake: 0.4655)', 'Generator loss: 1.3086')
('Epoch 289/300', 'Discriminator loss: 0.9450(Real: 0.5309 + Fake: 0.4141)', 'Generator loss: 1.4525')
('Epoch 290/300', 'Discriminator loss: 1.0533(Real: 0.5605 + Fake: 0.4928)', 'Generator loss: 1.3048')
('Epoch 291/300', 'Discriminator loss: 1.1212(Real: 0.6763 + Fake: 0.4449)', 'Generator loss: 1.3772')
('Epoch 292/300', 'Discriminator loss: 0.9949(Real: 0.5823 + Fake: 0.4126)', 'Generator loss: 1.4838')
('Epoch 293/300', 'Discriminator loss: 1.0056(Real: 0.5089 + Fake: 0.4967)', 'Generator loss: 1.2450')
('Epoch 294/300', 'Discriminator loss: 0.9519(Real: 0.4782 + Fake: 0.4738)', 'Generator loss: 1.3058')
('Epoch 295/300', 'Discriminator loss: 1.0492(Real: 0.5850 + Fake: 0.4641)', 'Generator loss: 1.3339')
('Epoch 296/300', 'Discriminator loss: 1.0828(Real: 0.6848 + Fake: 0.3980)', 'Generator loss: 1.4584')
('Epoch 297/300', 'Discriminator loss: 1.0169(Real: 0.5184 + Fake: 0.4985)', 'Generator loss: 1.2519')
('Epoch 298/300', 'Discriminator loss: 0.9329(Real: 0.4502 + Fake: 0.4828)', 'Generator loss: 1.3602')
('Epoch 299/300', 'Discriminator loss: 0.9473(Real: 0.5277 + Fake: 0.4196)', 'Generator loss: 1.4321')
('Epoch 300/300', 'Discriminator loss: 1.0214(Real: 0.6331 + Fake: 0.3883)', 'Generator loss: 1.4634')

Sample Images

As the images below show, after 195 epochs the Generator can already produce fairly good images, with only a small amount of noise. By the 299th epoch the generated images are essentially free of noise; even a human would have a hard time telling real from fake. The improvement brought by CGAN is quite noticeable!

195th Epoch
[Figure: generated samples at the 195th epoch]

299th Epoch
[Figure: generated samples at the 299th epoch]


References

1. Goodfellow et al., "Generative Adversarial Nets", NIPS 2014: http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
2. NELSONZHAO, mnist_gan examples: https://github.com/NELSONZHAO/zhihu/tree/master/mnist_gan
3. Mirza and Osindero, "Conditional Generative Adversarial Nets": https://arxiv.org/abs/1411.1784
