變分自編碼器VAE的由來和簡單實現(PyTorch)

變分自編碼器VAE的由來和簡單實現(PyTorch)

​ 之前經常遇到變分自編碼器的概念(\(VAE\)),但是自己對於這個概念總是模模糊糊,今天就係統的對\(VAE\)進行一些整理和回顧。

VAE的由來

​ 假設有一個目標數據\(X=\{X_1,X_2,\cdots,X_n\}\),我們想生成一些數據,即生成\(\hat{X}=\{\hat{X_1},\hat{X_2},\cdots,\hat{X_n}\}\),其分佈與\(X\)相同。

​ 但是實際上,這樣存在一些問題,第一是我們如何將生成的\(\hat{X}\)\(X\)一一對應,這就需要我們採用更爲精巧的度量方式,即如何度量兩個分佈之間的距離;第二是我們如何生成新的\(\hat{X}\),按照樸素的想法,我們可以構造一個函數\(G\),使得\(\hat{X}=G(Z)\) ,如果能構造出這個\(G\),我們就可以通過一個任意的\(Z\),來生成\(\hat{X}\) ,而這裏的\(Z\),可以取一個已知的分佈,比如正態分佈。

目前的問題

​ 目前的問題轉化爲了如何構造\(G\),以及如何檢驗我們生成的\(\hat{X}\)是否和\(X\)具有同分布。在\(GAN\)中,這裏的\(G\)和分佈的相似度衡量都用神經網絡搞定了,一個叫做\(generator\),一個叫做\(discriminator\),這二者互相拮抗,最終使得分佈越來接近。

​ 而在我們目前的問題中,\(VAE\)提供了另外一種思路,沿着AutoEncoder的想法,AutoEncoder是通過\(encoder\)把image \(a\)編碼爲vector,叫做\(latent{\ }represention\) ,再通過\(decoder\)\(latent{\ }space\)轉爲\(\hat{a}\) ,\(\hat{a}\)\(a\)的重建圖像。

​ 但是AE針對每張圖片生成的\(latent{\ }code\)並沒有可解釋性,即sample兩個\(latent{\ } code\)之間的點輸入\(decoder\),得到的結果並不一定具有跟這兩個\(latent code\)相關的特徵。爲了解決這個問題,提出了VAE:不再採用vector來建模一個\(latent{\ }code\),而是利用一個帶有noise的高斯分佈來表示。直觀的理解,在加入noise之後,就有機會將訓練時候train的\(latent{\ }code\)在其latent space下賦予一定的變化能力,使latent space變得更加連續,從而可以在其中採樣從而生成新的圖片。

​ 我們之前生成的\(Z=\{Z_1,Z_2,\cdots,Z_n\}\),現在不再單單生成一個\(Z\),而是生成兩個vector,分別記爲\(M=\{{\mu_1},{\mu_2},\cdots\,{\mu_n}\}\),\(\Sigma=\{ {\sigma_1},{\sigma_2},\cdots,\sigma_n\}\),分別代表新生成latent code的高斯分佈的均值和方差。在sample的時候就只需要根據從標準正態分佈\(\mathcal{N}(0,1)\)中採樣一個\(e_i\),\(e_i\)來自於\(E=\{e_1,e_2,\cdots,e_n\}\),然後利用\(c_i=e_i*exp({\sigma_i})+\mu_i\)(\(reparameterization{\ }trick\)),就得到了我們所需的\(c_i\)\(c_i\)即組成我們需要的\(Z\)=\(\{c_1,c_2,\cdots,c_n\}\)

​ 這裏一方面希望\(VAE\)能夠生成儘可能豐富的數據,因此訓練的時候希望在高斯分佈中含有噪聲。另一方面優化的過程中會趨向於使圖像質量更好,因此當噪聲爲0的時候退化爲普通的\(AutoEncoder\),這種情況我們是不希望出現的。爲了平衡這種trade-off,這裏希望每個\(p(Z|X)\)能夠接近標準正態分佈,但是另一方面網絡又趨於使輸入和輸出圖像更爲接近,因此會使正態分佈的方差向0的方向優化。經過這種對抗過程,最終就能產生具有一定可解釋性的\(decoder\),同時最終得到的\(Z\)的分佈也會趨向於\(\mathcal{N}(0,1)\),可以表示爲:

​ $$p(Z)=\sum_{X} p(Z \mid X) p(X)=\sum_{X} \mathcal{N}(0, 1) p(X)=\mathcal{N}(0, I) \sum_{X} p(X)=\mathcal{N}(0, 1)$$

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
    def loss_function_original(recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

​ 這裏的loss由兩部分組成,一部分是重建loss,一部分是使各個高斯分佈趨近於標準高斯分佈的loss(由KL散度推導得到)。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章