生成式對抗網絡(GAN)
簡介
生成式對抗網絡的框架主要包含兩個模型:一個是生成模型(Generator),記爲 G,用來生成數據,通過對大量樣本的學習,生成能夠以假亂真的數據樣本;另一個是判別模型(Discriminator),記爲 D,負責接收 G 生成的樣本數據和真實樣本數據,並對其進行判別和分類。生成網絡 G 接收一個隨機噪聲 z 並生成圖片,記爲 G(z);判別網絡 D 的作用是判別一張圖片 x 是否真實,對於輸入 x,D(x) 表示 x 爲真實圖片的概率。G 和 D 相互博弈,通過學習,G 的生成能力和 D 的判別能力逐漸增強,直到收斂。
原理
隨機生成一個服從某種先驗分佈的噪聲 z(下方代碼中使用 [-1, 1] 上的均勻分佈),生成器 G 通過一個複雜的映射關係將 z 映射爲假樣本 G(z)。
判別器 D 對輸入的真實樣本或假樣本輸出一個 0 到 1 之間的值,值越大,表示該樣本越有可能是真實樣本。
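總的目標函數:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$

判別器 D 希望最大化 $V(D, G)$,即把真實樣本判爲 1、把生成樣本判爲 0;生成器 G 則希望最小化 $V(D, G)$。在下方代碼中,生成器的損失採用了常見的替代形式:對 D(G(z)) 以 1 爲標籤計算交叉熵,即最大化 $\log D(G(z))$,以避免訓練初期梯度過小。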
代碼
# encoding: utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.examples.tutorials.mnist import input_data
# MNIST data set, read through the TensorFlow tutorial helper from 'MNIST_data/'.
mnist = input_data.read_data_sets('MNIST_data/')

# Hyper-parameters and output location.
batch_size = 100        # images per training step; also the number of images in each sample grid
z_dim = 100             # length of the generator's input noise vector
OUTPUT_DIR = 'samples'  # directory that receives the periodic sample images

if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# Graph inputs.
X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='X')        # real images, NHWC
Noise = tf.placeholder(tf.float32, [None, z_dim], name='Noise')    # generator input noise
is_training = tf.placeholder(tf.bool, name='is_training')          # batch-norm train/inference switch
def relu(x, leak=0.2):
    """Leaky ReLU computed as the element-wise maximum of ``x`` and ``leak * x``.

    For 0 < leak < 1 this passes non-negative values through unchanged and
    scales negative values by ``leak``.

    Args:
        x: input tensor.
        leak: slope applied to the negative part (default 0.2).

    Returns:
        A tensor with the same shape as ``x``.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
def sigmoid_cross_entropy_with_logits(x, y):
    """Element-wise sigmoid cross-entropy, taking (logits, labels) positionally.

    Args:
        x: logits tensor.
        y: labels tensor with the same shape as ``x``.

    Returns:
        The per-element loss tensor (not reduced).
    """
    logits, labels = x, y
    return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
# Discriminator
def discriminator(image, reuse=None, is_training=is_training):
    """Score a batch of images as real (near 1) or fake (near 0).

    Four 5x5, stride-2, 'same' convolutions with 64/128/256/512 filters. Each
    is followed by a leaky ReLU; batch norm is applied on every layer except
    the first. For 28x28 inputs the spatial size goes 28 -> 14 -> 7 -> 4 -> 2,
    then the features are flattened into a single-unit dense layer.

    Args:
        image: image batch to score.
        reuse: pass True to share variables with an earlier call (the fake-image
            branch reuses the real-image branch's weights).
        is_training: boolean tensor selecting batch-norm training/inference mode;
            defaults to the module-level ``is_training`` placeholder.

    Returns:
        ``(probability, logits)``: the sigmoid of the logits and the raw logits,
        each with a last dimension of 1.
    """
    bn_decay = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        # First stage: conv + leaky ReLU, no batch norm.
        net = relu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))
        # Remaining stages: conv -> batch norm -> leaky ReLU. Layers are created
        # in the same order as before, so auto-generated variable names match.
        for n_filters in (128, 256, 512):
            net = tf.layers.conv2d(net, kernel_size=5, filters=n_filters, strides=2, padding='same')
            net = relu(tf.contrib.layers.batch_norm(net, is_training=is_training, decay=bn_decay))
        flat = tf.contrib.layers.flatten(net)
        logits = tf.layers.dense(flat, units=1)
        return tf.nn.sigmoid(logits), logits
def generator(z, is_training=is_training):
    """Map noise vectors ``z`` to 28x28x1 images with values in [-1, 1] (tanh).

    The network has these stages:

    - A dense layer reshaped to 3x3x512.
    - Three 5x5, stride-2, 'same' transposed convolutions with 256/128/64
      filters (3 -> 6 -> 12 -> 24), each followed by batch norm and ReLU.
    - A final 5x5, stride-1, 'valid' transposed convolution to one channel
      (24 + 5 - 1 = 28).

    Args:
        z: noise batch whose last dimension is the noise size.
        is_training: boolean tensor selecting batch-norm training/inference mode;
            defaults to the module-level ``is_training`` placeholder.

    Returns:
        The generated image tensor (output layer named 'g').
    """
    bn_decay = 0.8
    with tf.variable_scope('generator', reuse=None):
        side = 3
        net = tf.layers.dense(z, units=side * side * 512)
        net = tf.reshape(net, shape=[-1, side, side, 512])
        net = tf.nn.relu(tf.contrib.layers.batch_norm(net, is_training=is_training, decay=bn_decay))
        # Upsampling stages, created in the original order so variable names match.
        for n_filters in (256, 128, 64):
            net = tf.layers.conv2d_transpose(net, kernel_size=5, filters=n_filters, strides=2, padding='same')
            net = tf.nn.relu(tf.contrib.layers.batch_norm(net, is_training=is_training, decay=bn_decay))
        # Stride-1 'valid' transposed conv grows 24x24 to 28x28 and maps to one channel.
        return tf.layers.conv2d_transpose(net, kernel_size=5, filters=1, strides=1, padding='valid',
                                          activation=tf.nn.tanh, name='g')
# Wire the networks: one generator, plus a discriminator applied to real and
# generated images with shared weights (the second call reuses the first).
g = generator(Noise)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)

# Trainable variables of each network, selected by variable-scope name prefix.
vars_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
vars_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

# D is trained to label real images 1 and generated images 0.
loss_d_real = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))
# G is trained to make D label its images 1 (the non-saturating generator loss).
loss_g = tf.reduce_mean(
    sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake

# Batch-norm moving-average updates live in UPDATE_OPS; making both optimizers
# depend on them ensures the statistics are refreshed on every training step.
updates_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(updates_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(
        loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(
        loss_g, var_list=vars_g)
def montage(images):
    """Tile a batch of images into one roughly square grid image.

    Tiles are laid out row-major on an ``n x n`` grid, where
    ``n = ceil(sqrt(N))``. A 1-pixel border of value 0.5 separates and
    surrounds the tiles; unused grid slots are left at 0.5.

    Args:
        images: sequence or array of images with shape (N, H, W) or
            (N, H, W, C). Lists, tuples and arrays are all accepted.

    Returns:
        A float ndarray of shape (n*H + n + 1, n*W + n + 1) plus any trailing
        channel dimension.

    Raises:
        ValueError: if ``images`` does not have at least 3 dimensions.
    """
    images = np.asarray(images)
    if images.ndim < 3:
        raise ValueError('montage expects images of shape (N, H, W[, C]), got shape %s'
                         % (images.shape,))
    n_images, image_h, image_w = images.shape[:3]
    n_plots = int(np.ceil(np.sqrt(n_images)))
    # One border pixel before every tile plus one after the last: hence "+ n_plots + 1".
    canvas_shape = (image_h * n_plots + n_plots + 1,
                    image_w * n_plots + n_plots + 1) + images.shape[3:]
    m = np.ones(canvas_shape) * 0.5
    for idx in range(n_images):
        i, j = divmod(idx, n_plots)  # row-major grid position
        top = 1 + i + i * image_h
        left = 1 + j + j * image_w
        m[top:top + image_h, left:left + image_w] = images[idx]
    return m
# Training: each iteration does one discriminator update followed by two
# generator updates on the same batch and noise. Every 20 iterations it logs
# the losses and renders a sample grid from a fixed noise batch.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

n_iters = 30000
# Fixed noise, so the periodically saved grids are comparable across iterations.
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
# NOTE(review): every 291x291 float64 grid (~0.65 MB) is kept in memory; with
# 1500 snapshots that is roughly 1 GB. Drop or downsample if memory is tight.
samples = []
loss = {'d': [], 'g': []}
for i in range(n_iters):
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    # Map pixels from [0, 1] (assumed input range) to [-1, 1], matching the generator's tanh output.
    batch = (batch - 0.5) * 2
    feed = {X: batch, Noise: n, is_training: True}

    # Record losses measured before this iteration's updates.
    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict=feed)
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    sess.run(optimizer_d, feed_dict=feed)
    # Two generator steps per discriminator step, presumably to keep D from overpowering G.
    sess.run(optimizer_g, feed_dict=feed)
    sess.run(optimizer_g, feed_dict=feed)

    if i % 20 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={Noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2  # back from [-1, 1] to [0, 1] for display
        imgs = [img[:, :, 0] for img in gen_imgs]
        gen_imgs = montage(imgs)
        # Draw each grid on its own figure and close it afterwards. plt.show()
        # does not necessarily close the figure (e.g. with a non-interactive
        # backend). Without this, images pile up on one axes and the final loss
        # curves would be drawn onto the last sample.
        fig = plt.figure()
        plt.axis('off')
        plt.imshow(gen_imgs, cmap='gray')
        plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
        plt.show()
        plt.close(fig)
        samples.append(gen_imgs)

plt.plot(loss['d'], label='discriminator')
plt.plot(loss['g'], label='generator')
plt.legend(loc='upper right')
plt.show()

saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=n_iters)
結果