CGAN論文詳解與代碼詳解

from: https://mp.weixin.qq.com/s?__biz=Mzg5NTEwNjQ3Ng==&mid=2247484324&idx=1&sn=57f25bc89164b8ef6ad1234f8bcfb49b&chksm=c0142f67f763a6718bb1f4bf35dced93cecdac0f128839f67e7e5aacd8ea7aeea9296193731a&mpshare=1&scene=1&srcid=1108VCtC1j9NzHgSvKVUcFBr&sharer_sharetime=1573207840909&sharer_shareid=bcd65ed3e6918861746f708528ff9349&key=0dd8ea2ddf54b7b6bc6891811a31087e753a1d82b9c5a3ecd66f1122575f70364473cea7635e03ea3707e0d794f8061c0c6342fbdf50271d49c2d50958171e5428f012e61a56499a7dbae611d569bd47&ascene=1&uin=MjI0NDQyMzUwMA%3D%3D&devicetype=Windows+10&version=62070158&lang=en&pass_ticket=LYK2FsN7veBklpbeKcHMb1K3blwAGAaoIsKPfv%2FLYc7XPgx5%2FHhkB0YEcx6fmTZh

作者:戴璞微
作者微信:dpw19960712
作者QQ:771830171
微信公衆號:AI那點小事
知乎專欄:AI那點小事
CSDN博客:https://daipuweiai.blog.csdn.net/
CSDN博客原文鏈接:https://daipuweiai.blog.csdn.net/article/details/102962215
知乎專欄原文鏈接:https://zhuanlan.zhihu.com/p/90835081

前言

自從10月15號在廣州的實習結束後,這將近1個月的時間由於學校各種實習相關手續、答辯和趕上畢業論文開題的節奏等原因,因此相關實習結束之前相關筆記沒有及時。從今天開始,將恢復相關博客的更新。

在之前我們介紹了DCGAN與原始GAN的相關理論,並給出了DCGAN生成手寫數字圖像的代碼。若有興趣請分別移步如下鏈接:

  1. 【GAN】一、利用keras實現DCGAN生成手寫數字圖像

  2. 【GAN】二、原始GAN論文詳解

  3. 【GAN】三、DCGAN論文詳解

本篇博客我們將介紹CGAN(條件GAN)論文的相關細節。CGAN的論文網址請移步:https://arxiv.org/pdf/1411.1784.pdf。CGAN生成手寫數字的keras代碼請移步:https://github.com/Daipuwei/CGAN-mnist。

一、 GAN回顧

爲了兼顧CGAN的相關理論介紹,我們首先回顧GAN相關細節。GAN主要包括兩個網絡,一個是生成器和判別器,生成器的目的就是將隨機輸入的高斯噪聲映射成圖像(“假圖”),判別器則是判斷輸入圖像是否來自生成器的概率,即判斷輸入圖像是否爲假圖的概率。

在這裏我們假設數據爲,生成器的數據分佈爲,噪聲分佈爲,那麼噪聲的結果可以記,數據在判別器上的結果爲

那麼GAN的目的就是無中生有,以假亂真。即要使得生成器生成的所謂的"假圖"騙過判別器,那麼最優狀態就是生成器生成的所謂的"假圖"在判別器的判別結果爲0.5,不知道到底是真圖還是假圖。GAN的目標函數如下:


二、CGAN網絡架構詳解

在介紹CGAN的原理接下來介紹了CGAN的相關原理。原始的GAN的生成器只能根據隨機噪聲進行生成圖像,至於這個圖像是什麼(即標籤是什麼我們無從得知),判別器也只能接收圖像輸入進行判別是否圖像來使生成器。因此CGAN的主要貢獻就是在原始GAN的生成器與判別器中的輸入中加入額外信息。額外信可以是任何信息,例如標籤。因此CGAN的提出使得GAN可以利用圖像與對應的標籤進行訓練,並在測試階段 利用給定標籤生成特定圖像。

在CGAN的論文中,網絡架構使用的MLP(全連接網絡)。在CGAN中的生成器,我們給定一個輸入噪聲0和額外信息,之後將兩者通過全連接層連接到一起作爲隱藏層輸入。同樣地,在判別器中輸入圖像和 額外信息也將連接到一起作爲隱藏層輸入。CGAN的網絡架構圖如下所示:


那麼,CGAN的目標函數可以表述成如下形式:

下面是CGAN論文中生成的手寫數字圖像的結果,每一行代表有一個標籤,例如第一行代表標籤爲0的圖片。

 


三、CGAN-MNIST代碼詳解

接下來我們將主要介紹CGAN生成手寫數字圖像的keras代碼。github鏈接爲:https://github.com/Daipuwei/CGAN-mnist。首先給出CGAN的網絡架構代碼:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 13:39
# @Author  : Dai PuWei
# @File    : CGAN.py
# @Software: PyCharm

import os
import cv2
import numpy as np
import datetime
import matplotlib.pyplot as plt

from scipy.stats import truncnorm


from keras import Input
from keras import Model
from keras import Sequential

from keras.layers import Dense
from keras.layers import Activation
from keras.layers import Reshape
from keras.layers import Conv2DTranspose
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.merge import multiply
from keras.layers.merge import concatenate
from keras.layers.merge import add
from keras.layers import Embedding
from keras.utils import to_categorical
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar
from copy import deepcopy
from keras.datasets import mnist

def make_trainable(net, val):
    """ Freeze or unfreeze layers
    """
    net.trainable = val
    for l in net.layers: l.trainable = val

class CGAN(object):

    def __init__(self,config,weight_path=None):
        """
        這是CGAN的初始化函數
        :param config: 參數配置類實例
        :param weight_path: 權重文件地址,默認爲None
        """
        self.config = config
        self.build_cgan_model()

        if weight_path is not None:
            self.cgan.load_weights(weight_path,by_name=True)

    def build_cgan_model(self):
        """
        這是搭建CGAN模型的函數
        :return:
        """
        # 初始化輸入
        self.generator_noise_input = Input(shape=(self.config.generator_noise_input_dim,))
        self.condational_label_input = Input(shape=(1,), dtype='int32')
        self.discriminator_image_input = Input(shape=self.config.discriminator_image_input_dim)

        # 定義優化器
        self.optimizer = Adam(lr=2e-4, beta_1=0.5)

        # 構建生成器模型與判別器模型
        self.discriminator_model = self.build_discriminator_model()
        self.discriminator_model.compile(optimizer=self.optimizer, loss=['binary_crossentropy'],metrics=['accuracy'])
        self.generator_model = self.build_generator()

        # 構建CGAN模型
        self.discriminator_model.trainable = False
        self.cgan_input = [self.generator_noise_input,self.condational_label_input]
        generator_output = self.generator_model(self.cgan_input)
        cgan_output = self.discriminator_model([generator_output,self.condational_label_input])
        self.cgan = Model(self.cgan_input,cgan_output)

        # 編譯
        #self.discriminator_model.compile(optimizer=self.optimizer,loss='binary_crossentropy')
        self.cgan.compile(optimizer=self.optimizer,loss=['binary_crossentropy'])

    def build_discriminator_model(self):
        """
        這是搭建生成器模型的函數
        :return:
        """
        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.config.discriminator_image_input_dim)))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.config.discriminator_image_input_dim)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.config.condational_label_num,
                                              np.prod(self.config.discriminator_image_input_dim))(label))
        flat_img = Flatten()(img)
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)

        return Model([img, label], validity)


    def build_generator(self):
        """
        這是構建生成器網絡的函數
        :return:返回生成器模型generotor_model
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.config.generator_noise_input_dim))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(np.prod(self.config.discriminator_image_input_dim), activation='tanh'))
        model.add(Reshape(self.config.discriminator_image_input_dim))

        model.summary()

        noise = Input(shape=(self.config.generator_noise_input_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.config.condational_label_num, self.config.generator_noise_input_dim)(label))

        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def train(self, train_datagen, epoch, k, batch_size=256):
        """
        這是DCGAN的訓練函數
        :param train_generator:訓練數據生成器
        :param epoch:週期數
        :param batch_size:小批量樣本規模
        :param k:訓練判別器次數
        :return:
        """
        time =datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        model_path = os.path.join(self.config.model_dir,time)
        if not os.path.exists(model_path):
            os.mkdir(model_path)

        train_result_path = os.path.join(self.config.train_result_dir,time)
        if not os.path.exists(train_result_path):
            os.mkdir(train_result_path)

        for ep in np.arange(1, epoch+1).astype(np.int32):
            cgan_losses = []
            d_losses = []
            # 生成進度條
            length = train_datagen.batch_num
            progbar = Progbar(length)
            print('Epoch {}/{}'.format(ep, epoch))
            iter = 0
            while True:
                # 遍歷一次全部數據集,那麼重新來結束while循環
                #print("iter:{},{}".format(iter,train_datagen.get_epoch() != ep))
                if train_datagen.epoch != ep:
                    break

                # 獲取真實圖片,並構造真圖對應的標籤
                batch_real_images, batch_real_labels = train_datagen.next_batch()
                batch_real_num_labels = np.ones((batch_size, 1))
                #batch_real_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                # 初始化隨機噪聲,僞造假圖,併合並真圖和假圖數據集
                batch_noises = np.random.normal(0, 1, size = (batch_size, self.config.generator_noise_input_dim))
                d_loss = []
                for i in np.arange(k):
                    # 構造假圖標籤,合併真圖和假圖對應標籤
                    batch_fake_num_labels = np.zeros((batch_size,1))
                    #batch_fake_num_labels = truncnorm.rvs(0.0, 0.3, size=(batch_size, 1))
                    batch_fake_labels = deepcopy(batch_real_labels)
                    batch_fake_images = self.generator_model.predict([batch_noises,batch_fake_labels])

                    # 訓練判別器
                    real_d_loss = self.discriminator_model.train_on_batch([batch_real_images,batch_real_labels],
                                                                                      batch_real_num_labels)
                    fake_d_loss = self.discriminator_model.train_on_batch([batch_fake_images, batch_fake_labels],
                                                                                      batch_fake_num_labels)
                    d_loss.append(list(0.5*np.add(real_d_loss,fake_d_loss)))
                #print(d_loss)
                d_losses.append(list(np.average(d_loss,0)))
                #print(d_losses)

                # 生成一個batch_size的噪聲來訓練生成器
                #batch_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                batch_num_labels = np.ones((batch_size,1))
                batch_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
                cgan_loss = self.cgan.train_on_batch([batch_noises,batch_labels], batch_num_labels)
                cgan_losses.append(cgan_loss)

                # 更新進度條
                progbar.update(iter, [('dcgan_loss', cgan_losses[iter]),
                                      ('discriminator_loss',d_losses[iter][0]),
                                      ('acc',d_losses[iter][1])])
                #print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (ep, d_losses[ep][0], 100 * d_losses[ep][1],cgan_loss))
                iter += 1
            if ep % self.config.save_epoch_interval == 0:
                model_cgan = "Epoch{}dcgan_loss{}discriminator_loss{}acc{}.h5".format(ep, np.average(cgan_losses),
                                                                                      np.average(d_losses,0)[0],np.average(d_losses,0)[1])
                self.cgan.save(os.path.join(model_path, model_cgan))
                save_dir = os.path.join(train_result_path, str("Epoch{}".format(ep)))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(int(ep), save_dir)
            '''
            if int(ep) in self.config.generate_image_interval:
                save_dir = os.path.join(train_result_path,str("Epoch{}".format(ep)))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(ep,save_dir)
            '''
        plt.plot(np.arange(epoch),cgan_losses,'b-','cgan-loss')
        plt.plot(np.arange(epoch), d_losses[0], 'b-', 'd-loss')
        plt.grid(True)
        plt.legend(locs="best")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(os.path.join(train_result_path,"loss.png"))

    def save_image(self, epoch,save_path):
        """
        這是保存生成圖片的函數
        :param epoch:週期數
        :param save_path: 圖片保存地址
        :return:
        """
        rows, cols = 10, 10

        fig, axs = plt.subplots(rows, cols)
        for i in range(rows):
            label = np.array([i]*rows).astype(np.int32).reshape(-1,1)
            noise = np.random.normal(0, 1, (cols, 100))
            images = self.generator_model.predict([noise,label])
            images = 127.5*images+127.5
            cnt = 0
            for j in range(cols):
                #img_path = os.path.join(save_path, str(cnt) + ".png")
                #cv2.imwrite(img_path, images[cnt])
                #axs[i, j].imshow(image.astype(np.int32)[:,:,0])
                axs[i, j].imshow(images[cnt,:, :, 0].astype(np.int32), cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(save_path, "mnist-{}.png".format(epoch)), dpi=600)
        plt.close()

    def generate_image(self,label):
        """
        這是僞造一張圖片的函數
        :param label:標籤
        """
        noise = truncnorm.rvs(-1, 1, size=(1, self.config.generator_noise_input_dim))
        label = np.array([label]).T
        image = self.generator_model.predict([noise,label])[0]
        image = 127.5*(image+1)
        return image

爲了訓練我們必須還構造一個數據集迭代器來讀取小批量手寫數字圖像數據,數據集迭代器類的代碼如下:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 17:29
# @Author  : Dai PuWei
# @File    : MnistGenerator.py
# @Software: PyCharm

import math
import numpy as np
from keras.datasets import mnist

class MnistGenerator(object):

    def __init__(self,batch_size):
        """
        這是圖像數據生成器的初始化函數
        :param batch_size: 小批量樣本規模
        """
        (x_train,y_train),(x_test,y_test) = mnist.load_data()
        #self.x = np.concatenate([x_train,x_test]).astype(np.float32)
        self.x = np.expand_dims((x_train.astype(np.float32)-127.5)/127.5,axis=-1)
        #self.y = to_categorical(np.concatenate([y_train,y_test]),num_classes=10)
        self.y = y_train.reshape(-1,1)
        #self.y = self.y[y == ]
        #print(np.shape(self.x))
        #print(np.shape(self.y))
        self.images_size = len(self.x)
        random_index = np.random.permutation(np.arange(self.images_size))
        self.x = self.x[random_index]
        self.y = self.y[random_index]

        self.epoch = 1                                  # 當前迭代次數
        self.batch_size = int(batch_size)
        self.batch_num = math.ceil(self.images_size / self.batch_size)
        self.start = 0
        self.end = 0
        self.finish_flag = False                        # 數據集是否遍歷完一次標誌

    def _next_batch(self):
        """
        :return:
        """
        while True:
            #batch_images = np.array([])
            #batch_labels = np.array([])
            if self.finish_flag:  # 數據集遍歷完一次
                random_index = np.random.permutation(np.arange(self.images_size))
                self.x = self.x[random_index]
                self.y = self.y[random_index]
                self.finish_flag = False
                self.epoch += 1
            self.end = int(np.min([self.images_size,self.start+self.batch_size]))
            batch_images = self.x[self.start:self.end]
            batch_labels = self.y[self.start:self.end]
            batch_size = self.end - self.start
            if self.end == self.images_size:            # 數據集剛分均分
                self.finish_flag = True
            if batch_size < self.batch_size:        # 小批次規模小於與預定規模,基本上是最後一組
                random_index = np.random.permutation(np.arange(self.images_size))
                self.x = self.x[random_index]
                self.y = self.y[random_index]
                batch_images = np.concatenate((batch_images, self.x[0:self.batch_size - batch_size]))
                batch_labels = np.concatenate((batch_labels, self.y[0:self.batch_size - batch_size]))
                self.start = self.batch_size - batch_size
                self.epoch += 1
            else:
                self.start = self.end
            yield batch_images,batch_labels

    def next_batch(self):
        datagen = self._next_batch()
        return datagen.__next__()

下面是相關訓練CGAN的代碼:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/8 15:43
# @Author  : Dai PuWei
# @File    : train.py
# @Software: PyCharm

import os
import datetime

from CGAN.CGAN import CGAN
from Config.Config import MnistConfig
from DataGenerator.MnistGenerator import MnistGenerator

def run_main():
    """
    這是主函數
    """
    cfg =  MnistConfig()
    cgan = CGAN(cfg)
    batch_size = 512
    #train_datagen = Cifar10Generator(int(batch_size/2))
    train_datagen = MnistGenerator(batch_size)
    cgan.train(train_datagen,100000,1,batch_size)


if __name__ == '__main__':
    run_main()

下面是訓練過程中的CGAN的生成的手寫數字圖像。第1個epoch之後的生成結果:

在這裏插入圖片描述


第10個epoch之後的生成結果:

在這裏插入圖片描述


第100個epoch之後的生成結果:

在這裏插入圖片描述


第1000個epoch之後的生成結果:

在這裏插入圖片描述


下面是CGAN的測試代碼:

 

# -*- coding: utf-8 -*-
# @Time    : 2019/11/8 13:11
# @Author  : DaiPuWei
# @Email   : [email protected]
# @File    : test.py
# @Software: PyCharm


import os
from CGAN.CGAN import CGAN
from Config.Config import MnistConfig

def run_main():
    """
    這是主函數
    """
    weight_path = os.path.abspath("./model/20191009134644/Epoch1378dcgan_loss1.5952800512313843discriminator_loss[0.49839333 0.7379193 ]acc[0.49839333 0.7379193 ].h5")
    result_path = os.path.abspath("./test_result")
    if not os.path.exists(result_path):
        os.mkdir(result_path)
    cfg =  MnistConfig()
    cgan = CGAN(cfg,weight_path)
    cgan.save_image(0,result_path)


if __name__ == '__main__':
    run_main()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章