VAE (Variational Autoencoder) Study Notes


  1. An ordinary autoencoder can encode information such as images into feature vectors.

  2. However, these feature vectors are usually not continuous in the latent space.

  3. A VAE (variational autoencoder) encodes image information into feature vectors that form a continuous latent space.

  4. It does this by injecting statistical structure into the encoder and decoder: each feature vector represents a Gaussian distribution, and the latent code is forced to follow that Gaussian.

  5. The encoder therefore maps an image to the parameters of a Gaussian distribution.

  6. The decoder samples a point from the latent space and restores it to an image of the same size as the input through dense layers, transposed-convolution layers, convolution layers, and so on.

  7. The loss function has two goals: (1) reconstruct the original image, and (2) give the latent space a well-behaved structure and reduce overfitting. The loss is therefore built from two terms; the second term pushes each encoded normal distribution toward the standard normal distribution (see the short sketch after this list).
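
As a concrete reference for point 7, the second loss term is the KL divergence between the encoded Gaussian N(mu, sigma^2) and the standard normal N(0, 1), which has a simple closed form. The snippet below is a minimal NumPy sketch of that formula; the function name and the example values of mu and log_var are illustrative only, not part of the model code.

import numpy as np

def kl_to_standard_normal(mu, log_var):
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2), summed over latent dimensions
    return -0.5 * np.sum(1 + log_var - np.square(mu) - np.exp(log_var))

print(kl_to_standard_normal(np.zeros(2), np.zeros(2)))           # 0.0: already the standard normal
print(kl_to_standard_normal(np.array([2.0, 0.0]), np.zeros(2)))  # 2.0: penalized for drifting away from N(0, 1)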


Implementation Code

import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np

img_shape = (28,28,1)
batch_size = 16
latent_dim = 2

##### Image encoder
input_img = keras.Input(shape = img_shape)

x = layers.Conv2D(32,3,padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)

shape_before_flatting = K.int_shape(x)


x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)

# The input image is finally encoded into the following two parameters
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
# Note that z_log_var = log(sigma^2), i.e. 2 * log(sigma)
##### End of encoder

##### Sampling function: draws a sample from the given normal distribution. This is where the statistical information (the reparameterization trick) enters the encoder.
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape = (K.shape(z_mean)[0],latent_dim),mean = 0.,stddev = 1.)
    return z_mean + K.exp(0.5*z_log_var) * epsilon
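
# (Illustrative aside, not part of the model) A quick NumPy check of the reparameterization
# trick used in sampling(): drawing eps ~ N(0, 1) and computing mu + sigma * eps yields samples
# distributed as N(mu, sigma^2), while mu and sigma themselves stay differentiable.
# The values demo_mu and demo_log_var below are made up for this check.
demo_mu, demo_log_var = 1.5, np.log(0.25)                    # i.e. sigma = 0.5
demo_z = demo_mu + np.exp(0.5 * demo_log_var) * np.random.normal(size = 100000)
print(demo_z.mean(), demo_z.std())                           # expected to be close to 1.5 and 0.5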


##### Draw the latent code z from the distribution predicted by the encoder
z = layers.Lambda(sampling)([z_mean,z_log_var])

##### VAE decoder
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flatting[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flatting[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation = 'relu',strides = (2,2))(x)  # upsample 14x14 back to 28x28
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)
##### At this point x has been restored to an image with the same size as the input

##### The following two lines connect the encoder and decoder through the sampling layer above
decoder = Model(decoder_input,x)
z_decoder = decoder(z)

##### Custom loss function: reconstruction term + KL-divergence term
def vae_loss(y_true,y_pred):
    x = K.flatten(y_true)
    z_decoded = K.flatten(y_pred)
    # Reconstruction loss: per-pixel binary cross-entropy against the input image
    xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)
    # Regularization loss: KL divergence between the encoded Gaussian and the standard normal,
    # scaled by a small factor so it does not overwhelm the reconstruction term
    kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),axis = -1)
    return K.mean(xent_loss + kl_loss)


vae = Model(input_img,z_decoder)
vae.compile(optimizer = 'rmsprop',loss = vae_loss)
vae.summary()


##### Train the model
from keras.datasets import mnist
(x_train,_),(x_test,y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

vae.fit(x = x_train,y = x_train,
        shuffle = True,
        epochs = 10,
        batch_size = batch_size,
        validation_data = (x_test,x_test))
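
# (Optional sketch, not in the original notes) Build a standalone encoder model that maps an
# input image to its latent mean, then scatter-plot the encoded test set colored by digit label.
# This gives a direct view of the 2-D latent space structure discussed in the notes above.
import matplotlib.pyplot as plt
encoder = Model(input_img, z_mean)
z_test = encoder.predict(x_test, batch_size = batch_size)
plt.figure(figsize = (6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c = y_test, cmap = 'viridis', s = 2)
plt.colorbar()
plt.xlabel('z_mean[0]')
plt.ylabel('z_mean[1]')
plt.show()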

##### Sample evenly across the latent space and inspect the decoded images
import matplotlib.pyplot as plt
from scipy.stats import norm
n = 24
digit_size = 28
figure = np.zeros((digit_size*n,digit_size*n))
grid_x = norm.ppf(np.linspace(0.02,0.98,n))
grid_y = norm.ppf(np.linspace(0.02,0.98,n))


for i,yi in enumerate(grid_x):
    for j,xi in enumerate(grid_y):
        z_sample = np.array([[xi,yi]])
        x_decoded = decoder.predict(z_sample, batch_size = 1)
        digit = x_decoded[0].reshape(digit_size,digit_size)
        figure[i*digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digit

plt.figure(figsize = (15,15))
plt.imshow(figure,cmap = 'Greys_r')
plt.show()

Output

(Figure: a 24 x 24 grid of MNIST digits decoded from evenly spaced points in the two-dimensional latent space.)
