Generative Adversarial Network (GAN): Handwritten Digit Image Generation

Generative Adversarial Networks (GANs)

Introduction

The GAN framework consists of two models. The first is the generator, denoted G, which learns from a large number of samples to produce data realistic enough to pass for genuine. The second is the discriminator, denoted D, which receives both samples produced by G and real samples and tries to tell them apart. The generator G takes a random noise vector z and produces an image G(z); the discriminator D judges whether an image x is real, so D(x) is the probability that x is a real image. G and D play a game against each other: through training, G's ability to generate and D's ability to discriminate both improve until the process converges.
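As a toy numerical illustration of these opposing objectives (plain NumPy, independent of the TensorFlow model below, using the same cross-entropy losses the code implements later): if the discriminator currently assigns probability p_real to a real sample and p_fake to a generated one, its loss falls as p_real approaches 1 and p_fake approaches 0, while the generator's loss falls as p_fake approaches 1.

import numpy as np

def d_loss(p_real, p_fake):
    # the discriminator wants real samples scored near 1 and fake samples near 0
    return -np.log(p_real) - np.log(1.0 - p_fake)

def g_loss(p_fake):
    # the generator wants its samples to be scored as real
    return -np.log(p_fake)

for p_fake in (0.1, 0.5, 0.9):
    print(p_fake, d_loss(0.8, p_fake), g_loss(p_fake))
# As p_fake rises, g_loss falls while d_loss rises: the two networks pull the
# discriminator's output in opposite directions, which is the game described above.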

Principle

Noise z is drawn from a random (prior) distribution, and the generator G maps it through a complex, learned transformation to a fake sample:

\hat{x} = G(z; \Theta_g)
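In the DCGAN-style generator defined in the code below, this mapping is a dense layer followed by a stack of transposed convolutions; tracing the tensor shapes for a 100-dimensional noise vector:

z                                                     -> (batch, 100)
dense + reshape                                       -> (batch, 3, 3, 512)
conv2d_transpose, stride 2                            -> (batch, 6, 6, 256)
conv2d_transpose, stride 2                            -> (batch, 12, 12, 128)
conv2d_transpose, stride 2                            -> (batch, 24, 24, 64)
conv2d_transpose, stride 1, kernel 5, padding 'valid' -> (batch, 28, 28, 1)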

Given a real or fake sample, the discriminator outputs a value between 0 and 1; the larger the value, the more likely the sample is judged to be real:

s = D(x; \Theta_d)
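In the implementation below, the discriminator actually produces a raw logit; the probability is obtained with a sigmoid, which guarantees the output lies in (0, 1):

s = D(x; \Theta_d) = \sigma(\mathrm{logit}), \qquad \sigma(t) = \frac{1}{1 + e^{-t}}

The code keeps both the sigmoid output and the raw logit, because the losses are computed from the logits via tf.nn.sigmoid_cross_entropy_with_logits for numerical stability.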

Overall objective function

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
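In practice, the code below optimizes this objective with per-network cross-entropy losses: the discriminator is trained to label real images 1 and generated images 0, while the generator uses the non-saturating form, which gives stronger gradients early in training than minimizing \log(1 - D(G(z))) directly:

L_D = -\mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
L_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]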

Code

# encoding: utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
# TensorFlow 1.x API; tensorflow.examples.tutorials was removed in TensorFlow 2.x
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/')
# Hyperparameters
batch_size = 100
z_dim = 100
OUTPUT_DIR = 'samples'
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# Placeholders: real images, latent noise vectors, and a flag for batch-norm mode
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
Noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='Noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def relu(x, leak=0.2):
    # Leaky ReLU: lets a small gradient through for negative inputs
    return tf.maximum(x, leak * x)

def sigmoid_cross_entropy_with_logits(x, y):
    # Convenience wrapper: sigmoid cross-entropy computed from raw logits
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)

# Discriminator: DCGAN-style stack of strided convolutions with leaky ReLU
# and batch norm, ending in a single logit per image
def discriminator(image, reuse=None, is_training=is_training):
    m = 0.9  # batch-norm moving-average decay
    with tf.variable_scope('discriminator', reuse = reuse):
        H0 = relu(tf.layers.conv2d(image, kernel_size = 5, filters = 64, strides=2, padding ='same'))

        H1 = tf.layers.conv2d(H0, kernel_size = 5, filters = 128, strides = 2,padding = 'same')
        H1 = relu(tf.contrib.layers.batch_norm(H1, is_training=is_training, decay = m))

        H2 = tf.layers.conv2d(H1, kernel_size = 5, filters = 256, strides = 2, padding = 'same')
        H2 = relu(tf.contrib.layers.batch_norm(H2, is_training = is_training, decay = m))

        H3 = tf.layers.conv2d(H2, kernel_size = 5 , filters = 512, strides = 2, padding = 'same')
        H3 = relu(tf.contrib.layers.batch_norm(H3, is_training = is_training, decay = m))

        H4 = tf.contrib.layers.flatten(H3)
        H4 = tf.layers.dense(H4, units=1)
        # Return both the probability and the raw logit (the losses use the logit)
        return tf.nn.sigmoid(H4), H4


# Generator: project the noise to a 3x3x512 tensor, then upsample with
# transposed convolutions to a 28x28x1 image with tanh output in [-1, 1]
def generator(z, is_training=is_training):
    m = 0.8  # batch-norm moving-average decay

    with tf.variable_scope('generator', reuse = None):
        d = 3
        H0 = tf.layers.dense(z, units = d*d*512)
        H0 = tf.reshape(H0, shape = [-1, d, d, 512])
        H0 = tf.nn.relu(tf.contrib.layers.batch_norm(H0, is_training=is_training, decay = m))

        H1 = tf.layers.conv2d_transpose(H0, kernel_size = 5, filters = 256, strides = 2, padding = 'same')
        H1 = tf.nn.relu(tf.contrib.layers.batch_norm(H1, is_training = is_training, decay = m))

        H2 = tf.layers.conv2d_transpose(H1, kernel_size = 5, filters = 128,strides = 2, padding = 'same')
        H2 = tf.nn.relu(tf.contrib.layers.batch_norm(H2 , is_training=is_training, decay = m ))

        H3 = tf.layers.conv2d_transpose(H2, kernel_size = 5, filters = 64, strides = 2, padding = 'same')
        H3 = tf.nn.relu(tf.contrib.layers.batch_norm(H3, is_training = is_training, decay = m))

        H4 = tf.layers.conv2d_transpose(H3, kernel_size = 5, filters= 1, strides = 1, padding = 'valid', activation=tf.nn.tanh, name = 'g')

        return H4
    
g = generator(Noise)

# Discriminator scores for real images and for generated images (weights shared via reuse)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)


# Split the trainable variables so each optimizer only updates its own network
vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

# Discriminator: real images labelled 1, generated images labelled 0
loss_d_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))

# Generator: non-saturating loss, i.e. push the discriminator to output 1 on generated images
loss_g = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake

# Batch norm keeps its moving averages in UPDATE_OPS; run them alongside the training steps
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)

def montage(images):
    # Tile a batch of same-sized grayscale images into one roughly square grid,
    # separated by a 1-pixel gray (0.5) border
    if isinstance(images, list):
        images = np.array(images)

    image_h = images.shape[1]
    image_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * image_h: 1 + i + (i + 1) * image_h, 1 + j + j * image_w: 1 + j + (j + 1) * image_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Fixed noise vectors, so the periodic sample grids are comparable across iterations
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)

samples = []
loss = {'d':[], 'g':[]}
for i in range(30000):
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    batch = (batch - 0.5) * 2  # rescale from [0, 1] to [-1, 1] to match the generator's tanh output
    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict={X: batch, Noise: n, is_training: True})
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    # One discriminator update, then two generator updates per iteration
    # (a common heuristic to keep the discriminator from overpowering the generator)
    sess.run(optimizer_d, feed_dict={X: batch, Noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, Noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, Noise: n, is_training: True})

    if i % 20 == 0:
        print(i,d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={Noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2  # map from [-1, 1] back to [0, 1]
        imgs = [img[:,:,0] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs, cmap = 'gray')
        plt.savefig(os.path.join(OUTPUT_DIR,'sample_%d.jpg'%i))
        plt.show()
        samples.append(gen_imgs)

plt.plot(loss['d'], label='discriminator')
plt.plot(loss['g'], label='generator')
plt.legend(loc='upper right')
plt.show()
saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=30000)
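To reuse the trained generator later without rerunning the whole script, the checkpoint saved above can be restored. A minimal sketch, assuming the mnist_dcgan-30000.* files were written by the saver.save call above; the tensor name of the generator output is an assumption and may need to be checked against the graph:

# restore_and_sample.py -- minimal sketch for reloading the checkpoint saved above
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

batch_size, z_dim = 100, 100
sess = tf.Session()
# Rebuild the graph from the meta file and restore the trained weights
saver = tf.train.import_meta_graph('./mnist_dcgan-30000.meta')
saver.restore(sess, './mnist_dcgan-30000')
graph = tf.get_default_graph()
Noise = graph.get_tensor_by_name('Noise:0')
is_training = graph.get_tensor_by_name('is_training:0')
# Assumed name of the generator's tanh output op; inspect the graph if it differs
g = graph.get_tensor_by_name('generator/g/Tanh:0')

n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
imgs = sess.run(g, feed_dict={Noise: n, is_training: False})
imgs = (imgs + 1) / 2  # map from [-1, 1] back to [0, 1]
plt.imshow(imgs[0, :, :, 0], cmap='gray')
plt.show()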

Results

Generated sample grids are written to the samples/ directory every 20 iterations, and the discriminator/generator loss curves are plotted after training.

Reference: https://zhuanlan.zhihu.com/p/44167207
