Python深度學習(10):VAE生成手寫數字

VAE簡介

自編碼器是接受一張圖像,通過編碼器將其映射到潛在向量空間,再通過解碼器將其解碼爲與圖像同樣大小的輸出。VAE是向自編碼器中添加了一些統計信息,迫使網絡學習連續的、高度結構化的潛在空間。
具體:
(1)將輸入圖像轉換爲潛在空間的z_mean和z_log_variance兩個參數。
(2)從z_means和z_log_variance所定義的潛在分佈中隨機採樣一個點
(3)使用解碼器模塊映射到原輸入圖像大小。

代碼

import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np
from keras.datasets import mnist
import matplotlib.pyplot as plt
from scipy.stats import norm

img_shape = (28, 28, 1) #輸入大小
batch_size = 16 #batch
latent_dim = 2 #潛在空間的維度

#搭建網絡
input_img = keras.Input(shape = img_shape)

x = layers.Conv2D(32, 3, padding='same', activation='relu')(input_img)
x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2,2))(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
shape_before_flattening = K.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)
#最終輸出結果z_mean和z_log_var
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
#潛在隨機空間採樣
def sampling(args):
    z_means, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0, stddev=1,)
    return z_mean + K.exp(0.5*z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var]) #包裝到Lambda層

#解碼器實現

decoder_input = layers.Input(K.int_shape(z)[1:])
#對輸入進行上採樣
x =layers.Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
#將x轉爲特徵圖,使其形狀與編碼器模型最後一個Flatten層之前的特徵圖的形狀相同
x = layers.Reshape(shape_before_flattening[1:])(x)
#將x解碼爲與原輸入圖像具有相同尺寸的特徵圖
x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2,2))(x)
x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)
#將解碼器模型實例化
decoder = Model(decoder_input, x)
#將實例用於z
z_decoded = decoder(z)
#用於計算VAE損失
class CustomVariationalLayer(keras.layers.Layer):
    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        k1_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + k1_loss)
    #編寫一個call方法實現自定義層
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs = inputs)
        return x

y = CustomVariationalLayer()([input_img, z_decoded])

vae = Model(input_img, y) #定義模型
vae.compile(optimizer='rmsprop', loss=None) #編譯模型
print(vae.summary())

(x_train, _), (x_test, y_test) = mnist.load_data() #加載數據

x_train = x_train.astype('float32') / 255. #對圖片除以255縮放
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

vae.fit(x=x_train, y=None,
        shuffle = True,
        epochs = 10,
        batch_size = batch_size,
        validation_data = (x_test, None)) #訓練模型

n = 15
digit_size = 28
figure = np.zeros((digit_size*n, digit_size*n))
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([xi, yi])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i*digit_size : (i+1)*digit_size, j*digit_size : (j+1)*digit_size] = digit

plt.figure(figsize = (10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

運行結果

在這裏插入圖片描述

推薦閱讀

VAE直觀理解
李宏毅老師VAE部分筆記

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章