前言
鉴于网上不少GAN网络的示例代码结构不够清晰,有些甚至无法正常迭代训练,我决定写一份结构完整、一目了然的GAN代码,以帮助那些和我一样刚开始接触这类网络的人。本篇中的生成器和判别器均由全连接层组成,以此复现最简单的GAN网络结构。
一、代码结构
代码由全局量、生成器、判别器、GAN网络、训练、范例图片生成以及载入模型生成图片这几个部分组成。
二、代码
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image
import cv2
import PIL
import json, os
import sys
import labelme
import labelme.utils as utils
import glob
import itertools
class GAN():
    """Minimal fully-connected GAN trained on the MNIST digits.

    The class bundles the configuration, a Dense-layer generator, a
    Dense-layer discriminator, the combined generator->discriminator model,
    the training loop, periodic sample-image dumping and a helper that
    reloads the saved model to generate one image.
    """

    def __init__(self,
                 save_path=r'C:\Users\Administrator\Desktop\photo\GAN.h5',
                 img_path=r'C:\Users\Administrator\Desktop\photo',
                 batch_size=20,
                 latent_dim=100,
                 sample_interval=1,
                 epoch=100,
                 ):
        """Store the configuration and build the three Keras models.

        Args:
            save_path: File the combined model is saved to / loaded from.
            img_path: Directory that receives the sample image grids.
            batch_size: Number of real (and of generated) images per step.
            latent_dim: Length of the noise vector fed to the generator.
            sample_interval: A sample grid is written every this many epochs.
            epoch: Number of passes over the training set.
        """
        # MNIST images are 28x28 with a single channel; load_data() is
        # MNIST-specific, so the image shape is not configurable.
        self.img_shape = (28, 28, 1)
        self.save_path = save_path
        self.img_path = img_path
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.sample_interval = sample_interval
        self.epoch = epoch
        # Build the sub-models first; the combined model wires them together.
        self.generator_model = self.build_generator()
        self.discriminator_model = self.build_discriminator()
        self.model = self.bulid_model()

    def _sample_noise(self, count) -> np.ndarray:
        """Return `count` standard-normal noise vectors of length latent_dim."""
        return np.random.normal(0, 1, (count, self.latent_dim))

    def build_generator(self):
        """Build the generator: noise vector -> image of shape img_shape.

        Returns:
            A Keras model named 'generator' whose output lies in [0, 1]
            because the last Dense layer uses a sigmoid activation.
        """
        noise_in = keras.Input(shape=(self.latent_dim,))
        x = noise_in
        # Three widening Dense -> LeakyReLU -> BatchNorm stages.
        for units in (256, 512, 1024):
            x = layers.Dense(units)(x)
            x = layers.LeakyReLU(alpha=0.2)(x)
            x = layers.BatchNormalization(momentum=0.8)(x)
        # One output unit per pixel, then fold back into image shape.
        x = layers.Dense(int(np.prod(self.img_shape)), activation='sigmoid')(x)
        output = layers.Reshape(self.img_shape)(x)
        model = keras.Model(inputs=noise_in, outputs=output, name='generator')
        model.summary()
        return model

    def build_discriminator(self):
        """Build the discriminator: image -> probability that it is real.

        Returns:
            A Keras model named 'discriminator' with a single sigmoid output.
        """
        image_in = keras.Input(shape=self.img_shape)
        x = layers.Flatten()(image_in)
        x = layers.Dense(512)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.Dense(256)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        output = layers.Dense(1, activation='sigmoid')(x)
        model = keras.Model(inputs=image_in, outputs=output, name='discriminator')
        model.summary()
        return model

    def bulid_model(self):
        """Compile the discriminator and build the combined GAN model.

        The method name keeps its original (misspelled) form so existing
        callers continue to work; `build_model` is an alias.

        Returns:
            The compiled combined model: noise -> generator -> discriminator.
        """
        # NOTE(review): the second positional argument of Adam is beta_1;
        # 1e-6 is unusually small (0.5 is a common GAN choice) - confirm
        # this value is intended.
        self.discriminator_model.compile(loss='binary_crossentropy',
                                         optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                                         metrics=['accuracy'])
        # Freeze the discriminator inside the combined model so that training
        # the combined model is meant to update only the generator. The
        # stand-alone discriminator was compiled before the flag was flipped;
        # this relies on Keras honouring the trainable state captured at
        # compile time - verify for the Keras version in use.
        self.discriminator_model.trainable = False
        inputs = keras.Input(shape=(self.latent_dim,))
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.summary()
        model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                      loss='binary_crossentropy',
                      )
        return model

    # Correctly spelled alias for bulid_model.
    build_model = bulid_model

    def load_data(self) -> np.ndarray:
        """Load the MNIST training images scaled to [0, 1].

        Returns:
            Array of training images with a trailing channel axis added.
            Labels and the test split are discarded.
        """
        (train_images, _), _ = keras.datasets.mnist.load_data()
        # Scale to [0, 1] to match the generator's sigmoid output range.
        train_images = train_images.astype('float32') / 255.0
        train_images = np.expand_dims(train_images, axis=3)
        print('img_number:', train_images.shape)
        return train_images

    def train(self) -> None:
        """Run the adversarial training loop and save the combined model."""
        train_images = self.load_data()
        # Target labels: 1 for real images, 0 for generated ones.
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        # Number of full batches per epoch; a trailing partial batch is dropped.
        step = train_images.shape[0] // self.batch_size
        print('step:', step)
        for epoch in range(self.epoch):
            # Reshuffle the training set once per epoch.
            train_images = (tf.random.shuffle(train_images)).numpy()
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)
            for i in range(step):
                start = i * self.batch_size
                imgs = train_images[start:start + self.batch_size]
                noise = self._sample_noise(self.batch_size)
                # verbose=0 avoids a progress bar being printed every step.
                gan_imgs = self.generator_model.predict(noise, verbose=0)
                # ---- train the discriminator: real -> 1, generated -> 0
                discriminator_loss_real = self.discriminator_model.train_on_batch(imgs, valid)
                discriminator_loss_fake = self.discriminator_model.train_on_batch(gan_imgs, fake)
                # Each result is [loss, accuracy]; average the two batches.
                discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
                # ---- train the generator through the combined model: it is
                # rewarded when the discriminator labels its output as real.
                noise = self._sample_noise(self.batch_size)
                generator_loss = self.model.train_on_batch(noise, valid)
                if i % 10 == 0:  # report every ten steps
                    print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                        epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))
            # Save after every epoch so an interrupted run still leaves a
            # usable model on disk; the last save is the final model.
            self.model.save(self.save_path)

    def generate_sample_images(self, epoch) -> None:
        """Write a 5x5 grid of generated images to `<img_path>/<epoch>.png`.

        Args:
            epoch: Epoch number, used as the output file name.
        """
        row, col = 5, 5
        noise = self._sample_noise(row * col)
        gan_imgs = self.generator_model.predict(noise, verbose=0)
        fig, axs = plt.subplots(row, col)
        for ax, img in zip(axs.flat, gan_imgs):
            ax.imshow(img[:, :, 0], cmap='gray')
            ax.axis('off')
        # Create the output directory on first use instead of failing.
        os.makedirs(self.img_path, exist_ok=True)
        fig.savefig(os.path.join(self.img_path, "%d.png" % epoch))
        # Close this specific figure so figures do not pile up across epochs.
        plt.close(fig)

    def pred(self) -> None:
        """Reload the saved combined model and display one generated image."""
        model = keras.models.load_model(self.save_path)
        model.summary()
        noise = self._sample_noise(1)
        # Look the generator up by the name given in build_generator()
        # rather than by its position in the layer list.
        generator = model.get_layer('generator')
        generator.summary()
        img = np.squeeze(generator.predict(noise, verbose=0))
        # Same gray colormap as the sample grids written during training.
        plt.imshow(img, cmap='gray')
        plt.show()
        print(img.shape)
if __name__ == '__main__':
    # Use a separate variable name: rebinding `GAN` to an instance would
    # hide the class and make a second `GAN()` call impossible.
    gan = GAN()
    gan.train()
三、实验现象
可以看到网络生成的数字越来越逼真了