VAE (Variational Autoencoder) Study Notes

  1. An ordinary encoder can encode information such as images into feature vectors.

  2. However, the feature space these vectors live in is usually not continuous: nearby vectors need not decode to similar images.

  3. A VAE (variational autoencoder) encodes an image into a feature vector drawn from a continuous, well-structured feature space.

  4. The method is to add statistical information to the encoder and decoder: the feature vector now represents a Gaussian distribution, and the encoding is forced to follow that Gaussian distribution.

  5. The encoder thus encodes an image into (the parameters of) a Gaussian distribution.

  6. The decoder samples from the feature space and passes the sample through dense, transposed-convolution, and convolution layers to recover an image of the same size as the input.

  7. The loss function has two goals: (1) reconstruct the original image, and (2) give the feature space a good structure while reducing overfitting. It therefore consists of two parts; the second part forces each encoded normal distribution to stay close to the standard normal distribution (the formulas are sketched right after this list).
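
The standard VAE formulas behind points 4-7 (they match the code below): the encoder outputs a mean mu and a log-variance log(sigma^2); a latent sample is drawn with the reparameterization trick, and the KL part of the loss has a closed form against the standard normal prior:

    z = mu + sigma * epsilon,  where epsilon ~ N(0, I)

    KL( N(mu, sigma^2) || N(0, 1) ) = -1/2 * sum_i ( 1 + log(sigma_i^2) - mu_i^2 - sigma_i^2 )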


Implementation code

import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np

img_shape = (28,28,1)
batch_size = 16
latent_dim = 2

##### Image encoder
input_img = keras.Input(shape = img_shape)

x = layers.Conv2D(32,3,padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)

shape_before_flattening = K.int_shape(x)   # remember the pre-flatten feature-map shape for the decoder


x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)

# The input image is finally encoded into the following two parameters
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
# Note that z_log_var = log(sigma^2) = 2*log(sigma)
##### End of the encoder

##### Sampling function: draws a sample from the given normal distribution; this is where the statistical information enters the encoder.
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape = (K.shape(z_mean)[0],latent_dim),mean = 0.,stddev = 1.)
    return z_mean + K.exp(0.5*z_log_var) * epsilon
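
##### Sanity check of the reparameterization trick (my addition, plain NumPy,
##### not part of the original notes): z = mu + sigma*eps has mean ~mu, std ~sigma
eps_demo = np.random.normal(size = 100000)
z_demo = 1.5 + np.exp(0.5 * np.log(0.25)) * eps_demo   # mu = 1.5, sigma = 0.5
print(z_demo.mean(), z_demo.std())                     # roughly 1.5 and 0.5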


##### VAE decoder
decoder_input = layers.Input(shape = (latent_dim,))   # a point in the latent space
x = layers.Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)
##### At this point x has been decoded back into a full-size image

##### The following lines connect the encoder and decoder through the sampling function
decoder = Model(decoder_input,x)
z = layers.Lambda(sampling)([z_mean,z_log_var])
z_decoded = decoder(z)

##### Custom loss function (it closes over z_mean and z_log_var from the encoder)
def vae_loss(y_true,y_pred):
    x = K.flatten(y_true)
    z_decoded_flat = K.flatten(y_pred)
    # reconstruction term: per-pixel binary cross-entropy against the input image
    xent_loss = keras.metrics.binary_crossentropy(x,z_decoded_flat)
    # regularization term: KL divergence from N(z_mean, exp(z_log_var)) to N(0, 1);
    # the small 5e-4 factor weighs it against the per-pixel reconstruction term
    kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),axis = -1)
    return K.mean(xent_loss + kl_loss)
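
##### Quick check (my addition): the KL term above is exactly 0 when the encoder
##### outputs the standard normal, i.e. z_mean = 0 and z_log_var = 0
mu_chk, log_var_chk = 0.0, 0.0
print(-0.5 * (1 + log_var_chk - mu_chk**2 - np.exp(log_var_chk)))   # prints -0.0, i.e. zero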


vae = Model(input_img,z_decoded)
vae.compile(optimizer = 'rmsprop',loss = vae_loss)
vae.summary()


##### Train the model
from keras.datasets import mnist
(x_train,_),(x_test,y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

vae.fit(x = x_train,y = x_train,
        shuffle = True,
        epochs = 10,
        batch_size = batch_size,
        validation_data = (x_test,x_test))
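
##### (Optional sketch, my addition) after training, reconstruct a few test digits
##### to eyeball reconstruction quality; the full VAE maps an image to its decoding
reconstructions = vae.predict(x_test[:10])
print(reconstructions.shape)   # (10, 28, 28, 1)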

##### Sample a regular grid of points in the latent space and look at the decoded images
import matplotlib.pyplot as plt
from scipy.stats import norm
n = 24
digit_size = 28
figure = np.zeros((digit_size*n,digit_size*n))
# norm.ppf maps evenly spaced probabilities to quantiles of N(0,1),
# matching the Gaussian prior over the latent space
grid_x = norm.ppf(np.linspace(0.02,0.98,n))
grid_y = norm.ppf(np.linspace(0.02,0.98,n))


for i,yi in enumerate(grid_x):
    for j,xi in enumerate(grid_y):
        z_sample = np.array([[xi,yi]])
        x_decoded = decoder.predict(z_sample, batch_size = 1)
        digit = x_decoded[0].reshape(digit_size,digit_size)
        figure[i*digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digit

plt.figure(figsize = (15,15))
plt.imshow(figure,cmap = 'Greys_r')
plt.show()

Output

(Figure: the 24x24 grid of digits decoded from the latent grid above.)
