GAN網絡之入門教程(四)之基於DCGAN動漫頭像生成

這一篇博客以代碼爲主,主要是來介紹如果使用keras構建一個DCGAN,然後基於DCGAN,做一個自動生成動漫頭像。訓練過程如下(50輪的訓練過程)“

關於DCGAN或者GAN的相關知識,可以參考GAN網絡入門教程。建議先了解相關知識,再來看這一篇博客。

項目地址:GitHub

使用前準備

首先的首先,我們肯定是需要數據集的,這裏使用的數據集來自kaggle——Anime Faces。裏面有21551張動漫頭像的圖片。大家可以到kaggle上面去下載數據集,或者說到我的github上去下載數據集(求個 ⭐ 不過分吧)。部分數據如下:

如果自己電腦計算機資源不是很強的話,比如我,一個mx250小水管(玩玩lol還是可以的,訓練這個模型可能要等到下輩子),推薦大家去註冊一個kaggle或者colab賬號去白嫖GPU資源(1080,2080的玩家請隨意)。不過個人更加的推薦kaggle,因爲感覺它的資源分配是可見的,且可以後臺運行。

數據集

數據集是動漫圖片,我們可以將圖片的像素點的值變成\([-1,1]\)之間,具體代碼如下:

# 數據集的位置
avatar_img_path = "./data"

import imageio
import os
import numpy as np
def load_data():
    """
    加載數據集
    :return: 返回numpy數組
    """
    all_images = []
    for image_name in os.listdir(avatar_img_path):
        # 加載圖片
        image =  imageio.imread(os.path.join(avatar_img_path,image_name))
        all_images.append(image)
    all_images = np.array(all_images)
    # 將圖片數值變成[-1,1]
    all_images = (all_images - 127.5) / 127.5
    # 將數據隨機排序
    np.random.shuffle(all_images)
    return all_images
img_dataset = load_data()

然後定義展示圖片的方法:


import matplotlib.pyplot as plt
def show_images(images,index = -1):
    """
    展示並保存圖片
    :param images: 需要show的圖片
    :param index: 圖片名
    :return:
    """
    plt.figure()
    for i, image in enumerate(images):
        ax = plt.subplot(5, 5, i+1)
        plt.axis('off')
        plt.imshow(image)
    plt.savefig("data_%d.png"%index)
    plt.show()
  • 展示數據集中的部分圖片:
show_images(img_dataset[0: 25])

定義參數

這裏我們只定義兩個參數,圖片的shape代表生成的圖片是\(64 \times 64\)的RGB圖片,以及noise的大小是100:

# noise的維度
noise_dim = 100
# 圖片的shape
image_shape = (64,64,3)

構建網絡

首先導入tensorflow中的keras庫,如下:

from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import UpSampling2D, Conv2D, Dense, BatchNormalization, LeakyReLU, Input,Reshape, MaxPooling2D, Flatten, AveragePooling2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam

下圖中的網絡結構參照了kaggle中的Anime face generation with DCGAN (beginner)

構建G網絡

生成器網絡,我們按照如下的結構進行構建:

原理是我們通過全連接層將nosise的向量放大,然後在再使用反捲積等操作將其逐漸變成shape爲\((64,64,3)\)的圖片。

def build_G():
    """
    構建生成器
    :return:
    """
    model = Sequential()
    # 全連接層 100 -> 2048
    model.add(Dense(2048,input_dim = noise_dim))
    # 激活函數
    model.add(LeakyReLU(0.2))
    # 全連接層 2048 ->  8 * 8 * 256
    model.add(Dense(8 * 8 * 256))
    # DN層
    model.add(BatchNormalization())
    model.add(LeakyReLU(0.2))
    # 8 * 8 * 256 -> (8,8,256)
    model.add(Reshape((8, 8, 256)))
    # 卷積層 (8,8,256) -> (8,8,128)
    model.add(Conv2D(128, kernel_size=5, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(0.2))
    # 反捲積層 (8,8,128) -> (16,16,128)
    model.add(Conv2DTranspose(128, kernel_size=5, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    # 反捲積層 (16,16,128) -> (32,32,64)
    model.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    # 反捲積層  (32,32,64) -> (64,64,3) = 圖片
    model.add(Conv2DTranspose(3, kernel_size=5, strides=2, padding='same', activation='tanh'))
    return model
G = build_G()

可以發現,\(G\)網絡並沒有compile這一步,這是因爲\(G\)網絡的權重優化並不是直接優化的,而是通過GAN網絡進行間接優化的。

構建D網絡

D網絡的結構示意圖如下:

判別器網絡就是一個尋常的CNN網絡:


def build_D():
    """
    構建判別器
    :return: 
    """
    model = Sequential()
    # 卷積層
    model.add(Conv2D(64, kernel_size=5, padding='valid',input_shape = image_shape))
    # BN層
    model.add(BatchNormalization())
    # 激活層
    model.add(LeakyReLU(0.2))
    # 平均池化層
    model.add(AveragePooling2D(pool_size=2))
    # 卷積層
    model.add(Conv2D(128, kernel_size=3, padding='valid'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(0.2))
    model.add(AveragePooling2D(pool_size=2))
    model.add(Conv2D(256, kernel_size=3, padding='valid'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(0.2))
    model.add(AveragePooling2D(pool_size=2))
    # 將輸入展平
    model.add(Flatten())
    # 全連接層
    model.add(Dense(1024))
    model.add(BatchNormalization())
    model.add(LeakyReLU(0.2))
    # 最終輸出1(true img) 0(fake img)的概率大小
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model
D = build_D()

構建GAN網絡

由前面的博客,我們知道,GAN網絡由G網絡和D網絡組成,GAN網絡的input爲nosie,輸出爲圖片真假的概率。因此它的網絡結構示意圖如下所示:


def build_gan():
    """
    構建GAN網絡
    :return:
    """
    # 冷凍判別器,也就是在訓練的時候只優化G的網絡權重,而對D保持不變
    D.trainable = False
    # GAN網絡的輸入
    gan_input = Input(shape=(noise_dim,))
    # GAN網絡的輸出
    gan_out = D(G(gan_input))
    # 構建網絡
    gan = Model(gan_input,gan_out)
    # 編譯GAN網絡,使用Adam優化器,以及加上交叉熵損失函數(一般用於二分類)
    gan.compile(loss='binary_crossentropy',optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return gan
GAN = build_gan()

關於GAN的小trick

我們會將真實的圖片的lable標記爲1,fake圖片的lable標記爲0,但是我們訓練的時候可以使lable的值在一定的範圍內浮動。關於更多的trick,可以參考這篇[GANs training tricks](https://zhuanlan.zhihu.com/p/76717276)。


def sample_noise(batch_size):
    """
    隨機產生正態分佈(0,1)的noise
    :param batch_size:
    :return: 返回的shape爲(batch_size,noise)
    """
    return np.random.normal(size=(batch_size, noise_dim))

def smooth_pos_labels(y):
    """
    使得true label的值的範圍爲[0.7,1.2]
    :param y:
    :return:
    """
    return y - 0.3 + (np.random.random(y.shape) * 0.5)

def smooth_neg_labels(y):
    """
    使得fake label的值的範圍爲[0.0,0.3]
    :param y:
    :return:
    """
    return y + np.random.random(y.shape) * 0.3

訓練

開始訓練之前,我們還介紹一個函數,load_batch,因爲我們訓練圖片不可能說一次將圖片全部進行訓練而是分批次進行訓練(full batch需要大量的內存空間),而load_batch函數就行按批次加載圖片。

def load_batch(data, batch_size,index):
    """
    按批次加載圖片
    :param data: 圖片數據集
    :param batch_size: 批次大小
    :param index: 批次序號
    :return:
    """
    return data[index*batch_size: (index+1)*batch_size]

然後我們就需要定義\(train\)函數了:


def train(epochs=100, batch_size=64):
    """
    訓練函數
    :param epochs: 訓練的次數
    :param batch_size: 批尺寸
    :return:
    """
    # 判別器損失
    discriminator_loss = 0
    # 生成器損失
    generator_loss = 0
    # img_dataset.shape[0] / batch_size 代表這個數據可以分爲幾個批次進行訓練
    n_batches = int(img_dataset.shape[0] / batch_size)
    
    for i in range(epochs):
        for index in range(n_batches):
            # 按批次加載數據
            x = load_batch(img_dataset, batch_size,index)
            # 產生noise
            noise = sample_noise(batch_size)
            # G網絡產生圖片
            generated_images = G.predict(noise)
            # 產生爲1的標籤
            y_real = np.ones(batch_size)
            # 將1標籤的範圍變成[0.7 , 1.2]
            y_real = smooth_pos_labels(y_real)
            # 產生爲0的標籤
            y_fake = np.zeros(batch_size)
            # 將0標籤的範圍變成[0.0 , 0.3]
            y_fake = smooth_neg_labels(y_fake)
            # 訓練真圖片loss
            d_loss_real = D.train_on_batch(x, y_real)
            # 訓練假圖片loss
            d_loss_fake = D.train_on_batch(generated_images, y_fake)

            discriminator_loss = d_loss_real + d_loss_fake
            # 產生爲1的標籤
            y_real = np.ones(batch_size)
            # 訓練GAN網絡,input = fake_img ,label = 1
            generator_loss = GAN.train_on_batch(noise, y_real)
        
        print('[Epoch {0}]. Discriminator loss : {1}. Generator_loss: {2}.'.format(i, discriminator_loss, generator_loss))
        # 隨機產生(25,100)的noise
        test_noise = sample_noise(25)
        # 使用G網絡生成25張圖偏
        test_images = G.predict(test_noise)
        # show 預測 img
        show_images(test_images,i)

開始訓練:

train(epochs=500, batch_size=32)

最後就進入到了漫長的等待結果的時間了。

總結

項目地址:GitHub

參考

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章