# A simple convolutional VAE in PyTorch (trained on MNIST).

import torch

import torch.nn as nn

from torch.autograd import Variable

import torch.optim as optim

from torchvision import transforms,datasets

import torch.nn.functional as F

import os

import scipy

import numpy as np

from scipy import misc

import math

# ---------------------------------------------------------------------------
# Hyper-parameters and data pipeline.
# ---------------------------------------------------------------------------
batch_size = 64            # samples per mini-batch
latent_vector = 32         # dimensionality of the VAE latent code z
intermediate_vector = 256  # unused below; kept for backward compatibility
num_class = 10             # MNIST digit classes (unused below)
det = 1e-10                # numerical-stability epsilon (unused below)
lamb = 2.5                 # reconstruction-loss weight (unused below)

# MNIST images converted to float tensors in [0, 1].
train_data = datasets.MNIST(root='./data/', train=True,
                            transform=transforms.ToTensor(), download=True)
# BUG FIX: the test split must be loaded with train=False — the original
# loaded the training split twice.
test_data = datasets.MNIST(root='./data/', train=False,
                           transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                          batch_size=batch_size, shuffle=True)

class VAE(nn.Module):
    """Convolutional VAE for 28x28 single-channel images (e.g. MNIST).

    Encoder: four convs (two strided) down to a 7x7x64 feature map, then two
    linear heads producing the mean and log-variance of the latent Gaussian.
    Decoder: linear projection back to 7x7x64 followed by transposed convs
    up to a 1x28x28 image; sigmoid keeps outputs in [0, 1].
    """

    def __init__(self, latent_dim=None):
        """latent_dim: size of the latent code. Defaults to the module-level
        `latent_vector` constant, so existing callers are unaffected."""
        super(VAE, self).__init__()
        if latent_dim is None:
            latent_dim = latent_vector

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(32, 32, 3, padding=1, stride=2),   # 28 -> 14
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(32, 64, 3, padding=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(64, 64, 3, padding=1, stride=2),   # 14 -> 7, i.e. 7*7*64
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.fc_mu = nn.Linear(7 * 7 * 64, latent_dim)
        self.fc_logvar = nn.Linear(7 * 7 * 64, latent_dim)
        self.fc = nn.Linear(latent_dim, 7 * 7 * 64)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 4, padding=1, stride=2),  # 7 -> 14
            nn.LeakyReLU(negative_slope=0.2),
            nn.ConvTranspose2d(64, 32, 3, padding=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.ConvTranspose2d(32, 32, 4, padding=1, stride=2),  # 14 -> 28
            nn.LeakyReLU(negative_slope=0.2),
            nn.ConvTranspose2d(32, 1, 3, padding=1, stride=1),
            nn.Sigmoid(),
        )

    def Reparameter(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        BUG FIX: uses torch.randn_like so the noise inherits mu's device and
        dtype — the original wrapped torch.randn in the deprecated Variable,
        which breaks as soon as the model is moved to a GPU.
        """
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for x of shape (N, 1, 28, 28)."""
        # BUG FIX: the original ran the encoder twice (once for mu, once for
        # logvar), doubling the forward cost; one pass feeds both heads.
        features = self.encoder(x).view(x.size(0), -1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        z = self.Reparameter(mu, logvar)
        vector = self.fc(z).view(z.size(0), 64, 7, 7)
        return self.decoder(vector), mu, logvar

# Single model instance trained below; stays on CPU (no .to(device) here).
model = VAE()

# Sum-reduced reconstruction loss.
# FIX: `size_average=False` has been deprecated since PyTorch 0.4; the
# modern equivalent is reduction='sum'.
MSE_Loss = nn.MSELoss(reduction='sum')
# Alternative reconstruction term kept from the original as a note:
# MSE_Loss = F.binary_cross_entropy(reduction='sum')

def loss_function(input, output, mu, logvar):
    """Standard VAE loss: sum-reduced MSE reconstruction term plus the
    analytic KL divergence between N(mu, sigma^2) and N(0, I).

    input:  original images, shape (N, 1, 28, 28)
    output: decoder reconstructions, same shape as `input`
    mu, logvar: latent Gaussian parameters, shape (N, latent_dim)

    Returns a scalar tensor summed over the whole batch.
    """
    # Sum-reduced MSE; equivalent to the module-level MSE_Loss but
    # self-contained and not tied to the deprecated size_average flag.
    Mse = F.mse_loss(output, input, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + e^logvar - logvar - 1)
    KL_loss = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
    return Mse + KL_loss

def save_image(output, size, path, Color):
    """Tile a batch of images into a size[0] x size[1] grid and save it.

    output: channels-last array, shape (N, H, W, 3) when Color is True,
            else (N, H, W); values expected in [0, 1]
    size:   [columns, rows] of the output grid
    path:   destination file path
    Color:  True for RGB, False for grayscale
    """
    h, w = output.shape[1], output.shape[2]

    if Color is True:
        image = np.zeros((w * size[0], h * size[1], 3))
    else:
        image = np.zeros((w * size[0], h * size[1]))

    for index, data in enumerate(output):
        i = index % size[0]     # grid column
        # BUG FIX: the row index must divide by the number of columns
        # (size[0]); the original divided by size[1], which places tiles
        # wrongly for non-square grids.
        j = index // size[0]    # grid row

        if Color is True:
            image[h*j : h*j+h, w*i : w*i+w, :] = data
        else:
            # BUG FIX: original sliced `h*j : h*j+j`, truncating each
            # grayscale tile to j rows instead of h rows.
            image[h*j : h*j+h, w*i : w*i+w] = data

    # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2; on a modern
    # SciPy this line raises AttributeError — replace with
    # PIL.Image.fromarray(...).save(path) when upgrading.
    scipy.misc.toimage((image*255), cmin=0, cmax=255).save(path)

def rescale_image(image):
    """Map an image to pixel values via (image / 1.5 + 0.5) * 255.

    Identical to the original formula, written as two explicit steps.
    """
    shifted = image / 1.5 + 0.5
    return shifted * 255

# Plain SGD with a small learning rate; NOTE(review): Adam is the more common
# choice for VAEs — confirm SGD convergence is acceptable here.
optimizer = optim.SGD(model.parameters(), lr= 0.0001)

def train():
    """Train the module-level `model` on MNIST for 9 epochs.

    Every 50 batches, dumps an 8x8 grid of reconstructions to ./image/ and
    prints the current loss. Uses the module-level `model`, `optimizer`,
    `train_loader`, `loss_function` and `save_image`.
    """
    for epoch in range(1, 10):  # epochs 1..9, matching the original range
        for i, (data, _) in enumerate(train_loader):
            # Variable() has been a no-op since PyTorch 0.4; plain tensors
            # work directly in autograd.
            tensor_data = data

            output, mu, logvar = model(tensor_data)

            optimizer.zero_grad()
            loss = loss_function(tensor_data, output, mu, logvar)
            loss.backward()
            optimizer.step()

            if i % 50 == 0:
                # exist_ok avoids the check-then-create race of the original
                # os.path.exists / os.mkdir pair.
                os.makedirs("./image", exist_ok=True)

                # (N, C, H, W) -> (N, H, W, C), as save_image expects.
                np_output = output.detach().numpy().transpose(0, 2, 3, 1)
                save_image(np_output, [8, 8], './image/image_{}.png'.format(i), True)

                # .item() prints a plain float instead of a tensor repr.
                print("loss={}".format(loss.item()))

# Kick off training immediately; note this also runs on import since there is
# no `if __name__ == "__main__":` guard.
train()


# --- Blog scrape residue (comment-section boilerplate), commented out: ---
# "Post a comment / All comments / No one has commented yet — want to be the
#  first? Type in the comment box above and click publish. / Related articles"