tensorflow keras 完整GAN网络代码(面向对象) 利用MNIST手写数据集生成手写数字

前言

鉴于网上对于GAN网络代码的结构不太好,甚至无法做到迭代,我决定就GAN网络来写一个拥有能够一目了然的完整结构的代码,以帮助那些和我一样刚开始接触这类网络的人,本篇中的GAN网络由全连接层组成,以此来复现最简单的GAN网络结构。

一、代码结构

代码由全局量、生成器、判别器、GAN网络、训练、范例图片生成以及载入模型生成图片这几个结构组成

二、代码

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image


import cv2
import PIL
import json, os
import sys

import labelme
import labelme.utils as utils
import glob
import itertools

class GAN():
    def __init__(self,      #定义全局变量
                 ):
        self.img_shape = (28, 28, 1)
        self.save_path = r'C:\Users\Administrator\Desktop\photo\GAN.h5'
        self.img_path = r'C:\Users\Administrator\Desktop\photo'
        self.batch_size = 20
        self.latent_dim = 100
        self.sample_interval=1
        self.epoch=100
        #建立GAN模型的方法
        self.generator_model = self.build_generator()
        self.discriminator_model = self.build_discriminator()
        self.model = self.bulid_model()


    def build_generator(self):#生成器

        input=keras.Input(shape=self.latent_dim)

        x=layers.Dense(256)(input)
        x=layers.LeakyReLU(alpha=0.2)(x)
        x=layers.BatchNormalization(momentum=0.8)(x)

        x = layers.Dense(512)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        x = layers.Dense(1024)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        x=layers.Dense(np.prod(self.img_shape),activation='sigmoid')(x)
        output=layers.Reshape(self.img_shape)(x)

        model=keras.Model(inputs=input,outputs=output,name='generator')
        model.summary()
        return model

    def build_discriminator(self):#判别器

        input=keras.Input(shape=self.img_shape)

        x=layers.Flatten(input_shape=self.img_shape)(input)
        x=layers.Dense(512)(x)
        x=layers.LeakyReLU(alpha=0.2)(x)
        x=layers.Dense(256)(x)
        x=layers.LeakyReLU(alpha=0.2)(x)
        output=layers.Dense(1,activation='sigmoid')(x)

        model=keras.Model(inputs=input,outputs=output,name='discriminator')
        model.summary()
        return model

    def bulid_model(self):#建立GAN模型
        self.discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                                    metrics=['accuracy'])

        self.discriminator_model.trainable = False#使生成器不训练

        inputs = keras.Input(shape=self.latent_dim)
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        model.summary()
        model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.000001),
                      loss='binary_crossentropy',
                      )
        return model

    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
        train_images = train_images /255
        train_images = np.expand_dims(train_images, axis=3)
        print('img_number:',train_images.shape)
        return train_images

    def train(self):
        train_images=self.load_data()#读取数据

        #生成标签
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))

        step=int(train_images.shape[0]/self.batch_size)#计算步长
        print('step:',step)

        for epoch in range(self.epoch):
            train_images = (tf.random.shuffle(train_images)).numpy()#每个epoch打乱一次
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)

            for i in range(step):

                idx = np.arange(i*self.batch_size,i*self.batch_size+self.batch_size,1)#生成索引
                imgs =train_images[idx]#读取索引对应的图片
                noise = np.random.normal(0, 1, (self.batch_size, 100))  # 生成标准的高斯分布噪声
                gan_imgs = self.generator_model.predict(noise)#通过噪声生成图片
                #----------------------------------------------训练判别器
                discriminator_loss_real = self.discriminator_model.train_on_batch(imgs, valid)  # 真实数据对应标签1
                discriminator_loss_fake = self.discriminator_model.train_on_batch(gan_imgs, fake)  # 生成的数据对应标签0
                discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
                #----------------------------------------------- 训练生成器
                noise = np.random.normal(0, 1, (self.batch_size, 100))
                generator_loss = self.model.train_on_batch(noise, valid)
                if i%10==0:#每十步进行输出
                    print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                        epoch,i,discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))



        self.model.save(self.save_path)#存储模型

    def generate_sample_images(self, epoch):#生成图片

        row, col = 5, 5#行列的数字
        noise = np.random.normal(0, 1, (row * col, self.latent_dim))#生成噪声
        gan_imgs = self.generator_model.predict(noise)
        fig, axs = plt.subplots(row, col)#生成5*5的画板
        idx = 0

        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(self.img_path+"/%d.png" % epoch)
        plt.close()#关闭画板

    def pred(self):#载入模型并生成图片
        model=keras.models.load_model(self.save_path)
        model.summary()
        noise = np.random.normal(0, 1, (1, self.latent_dim))

        generator=keras.Model(inputs=model.layers[1].input,outputs=model.layers[1].output)
        generator.summary()
        img=np.squeeze(generator.predict([noise]))
        plt.imshow(img)
        plt.show()
        print(img.shape)



if __name__ == '__main__':
    GAN = GAN()
    GAN.train()


三、实验现象

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可以看到网络生成的数字越来越逼真了

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章