Auto-Encoder實戰
對於監督學習,損失由預測值與標籤比較得到:
loss = criteon(predict, label)
而對於無監督學習(如 Auto-Encoder),沒有標籤,損失由重建結果與輸入本身比較得到:
loss = criteon(x_hat, x)
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch import nn, optim
from ae import AE
import visdom


def main():
    """Train a plain auto-encoder on MNIST and show reconstructions in visdom."""
    mnist_train = datasets.MNIST(
        'D:/py/dataset/', True,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST(
        'D:/py/dataset/', False,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # Peek at one batch to confirm the input shape: [b, 1, 28, 28].
    # next(iter(...)) replaces the Python 2 style iterator.next(),
    # which no longer exists on Python 3 / current PyTorch.
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    # Fall back to CPU when CUDA is unavailable instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AE().to(device)
    # Unsupervised objective: reconstruction error between x_hat and x.
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat = model(x)
            loss = criteon(x_hat, x)

            # back-propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss here is from the last batch of the epoch.
        print(epoch, 'loss:', loss.item())

        # Visualize reconstructions on one held-out test batch.
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()
ae.py
import torch
from torch import nn


class AE(nn.Module):
    """Fully-connected auto-encoder for flattened 28x28 MNIST images.

    The encoder compresses [b, 784] down to a [b, 20] code and the
    decoder mirrors it back up to [b, 784].
    """

    def __init__(self):
        super(AE, self).__init__()
        # Dimensionality reduction: 784 -> 256 -> 64 -> 20.
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        )
        # Dimensionality expansion: 20 -> 64 -> 256 -> 784.
        # The final Sigmoid squashes pixel values into [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct x; input and output are both [b, 1, 28, 28]."""
        b = x.size(0)
        flat = x.view(b, 784)        # flatten each image to a vector
        code = self.encoder(flat)    # latent code: [b, 20]
        recon = self.decoder(code)   # reconstruction: [b, 784]
        return recon.view(b, 1, 28, 28)
VAE實戰
在ae.py的基礎上
vae.py
import torch
from torch import nn


class VAE(nn.Module):
    """Variational auto-encoder for flattened 28x28 MNIST images.

    The encoder emits 20 numbers per sample, split into mu [b, 10]
    and log-variance [b, 10]. forward() returns (x_hat, kld).
    """

    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]  (mu and logvar, 10 each).
        # NOTE: no trailing ReLU here — the original ended the encoder
        # with ReLU, which clamps mu (and the spread parameter) to be
        # non-negative and cripples the latent Gaussian.
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
        )
        # [b, 10] => [b, 784]; Sigmoid squashes pixels into [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return (x_hat [b, 1, 28, 28], kld scalar) for x [b, 1, 28, 28]."""
        batchsz = x.size(0)
        x = x.view(batchsz, 784)  # flatten

        # [b, 20] => mu [b, 10] and logvar [b, 10].
        h_ = self.encoder(x)
        mu, logvar = h_.chunk(2, dim=1)

        # Reparameterization trick: z = mu + std * eps, eps ~ N(0, 1).
        # Predicting log-variance keeps std strictly positive, so no
        # 1e-8 guard is needed inside the log of the KL term.
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)

        x_hat = self.decoder(z).view(batchsz, 1, 28, 28)

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over batch
        # and pixels to match the scale of the per-pixel MSE loss.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) \
            / (batchsz * 28 * 28)

        return x_hat, kld
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch import nn, optim
from vae import VAE
import visdom


def main():
    """Train a VAE on MNIST and show reconstructions in visdom."""
    mnist_train = datasets.MNIST(
        'D:/py/dataset/', True,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST(
        'D:/py/dataset/', False,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # Peek at one batch to confirm the input shape: [b, 1, 28, 28].
    # next(iter(...)) replaces the Python 2 style iterator.next(),
    # which no longer exists on Python 3 / current PyTorch.
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    # Fall back to CPU when CUDA is unavailable instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)
            if kld is not None:
                # Negative ELBO = reconstruction loss + weighted KL term.
                elbo = -loss - 1.0 * kld
                loss = -elbo

            # back-propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss/kld here are from the last batch of the epoch.
        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        # Visualize reconstructions on one held-out test batch.
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()
x_hat 是模型重建(生成)出來的圖像。
訓練時 loss(重建誤差)和 kld 的數值差距還是挺大的:兩者量綱不同,必要時可以調整 KL 項前面的權重係數(上面程式裡是 1.0)來平衡兩者。