GAN的基本總結和小型demo
關於GANS(Generative Adversarial Networks)
屬於生成模型(generative models)
屬於無監督學習(unsupervised learning)
在不給定目的值的情況下,學習所給數據的底層結構。
目前可生成最清晰的圖像。
易於訓練(不需要統計推斷),只需要反向推斷就能夠獲得梯度。
由於訓練動態不穩定,難以優化。
基本不能做統計推斷。
屬於直接隱式密度模型,沒有明確定義概率分佈函數模型。
Generator和Discriminator
Discriminator
最大化被分類爲屬於真數據集的真數據輸入
最小化被分類爲屬於真數據集的假數據輸入
Generator
最大化被分類爲屬於真數據集的假數據輸入
這意味着用於此網絡的損耗/誤差函數(loss/error函數)要最大化
經過許多步的訓練,Generator和Discriminator都有足夠的能力,均不能再進行改進,此時Generator就能生成真實的合成數據,而Discriminator已經無法區分。
訓練GAN的基本步驟
1.採樣噪聲集和真實數據集,每個數據集具有大小m。
2.在這個數據上訓練鑑別器。
3.採樣具有大小m的不同噪聲子集。
4.根據這個數據訓練生成器。
5.從步驟1重複。
GAN的小型demo
1.導入相關庫
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as t
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dataset
import numpy as np
# 繪製圖像庫
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
2.設置plt屬性
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設置大小
plt.rcParams['image.interpolation'] = 'nearest' # 設置插值模式
plt.rcParams['image.cmap'] = 'gray' # 設置顏色
3.圖片顯示
def show_images(images):
images = np.reshape(images, [images.shape[0], -1]) # -1代表自動計算
sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) # np.ceil取整
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off') # 去掉座標軸
ax.set_xticklabels([]) # 設置x標記爲空
ax.set_yticklabels([]) # 設置y標記爲空
ax.set_aspect('equal')
plt.imshow(img.reshape([sqrtimg, sqrtimg]))
return
4.採樣函數
# 採樣函數爲自己定義的序列採樣(即按順序採樣)
class Sampler(sampler.Sampler):
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start
def __iter__(self):
return iter(range(self.start, self.start + self.num_samples))
def __len__(self):
return self.num_samples
5.訓練集和測試集的設置
NUM_TRAIN = 60000 # 訓練集數量
NUM_VAL = 10000 # 測試集數量
NOISE_DIM = 96 # 噪聲維度
batch_size = 128 # 批尺寸
mnist_train = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_train = DataLoader(mnist_train, batch_size=batch_size, sampler=Sampler(NUM_TRAIN, 0))
# 從0位置開始採樣NUM_TRAIN個數
mnist_val = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_val = DataLoader(mnist_val, batch_size=batch_size, sampler=Sampler(NUM_VAL, NUM_TRAIN))
# 從NUM_TRAIN位置開始採樣NUM_VAL個數
imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
show_images(imgs) # 顯示訓練集圖片
6.均勻噪聲函數
def sample_noise(batch_size, dim):
"""
- 產生一個從-1 ~ 1的均勻噪聲函數,形狀爲 [batch_size, dim].
參數:
- batch_size: 整型 提供生成的batch_size
- dim: 整型 提供生成維度
"""
temp = torch.rand(batch_size, dim) + torch.rand(batch_size, dim)*(-1)
return temp
7.平鋪函數
# 平鋪函數
class Flatten(nn.Module):
def forward(self, x):
n, c, h, w = x.size() # 讀取爲n,c,h,w
return x.view(n, -1) # 每張圖片把c*h*w的值傳入單向量用於後期處理
8.判別器
# 判別器 判斷generator產生的圖像是否爲假,同時判斷正確的圖像是否爲真
def discriminator():
model = nn.Sequential(
Flatten(),
nn.Linear(784, 256),
nn.LeakyReLU(0.01, inplace=True),
nn.Linear(256, 256),
nn.LeakyReLU(0.01, inplace=True),
nn.Linear(256, 1)
)
return model
9.生成器
# 生成器
def generator(noise_dim=NOISE_DIM):
model = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 784),
nn.Tanh(),
)
return model
10.損失函數
# GAN中指出的最大化最小化損失的算法
Bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake):
loss = None
# Batch size.
n = logits_real.size()
# 目標label,全部設置爲1意味着判別器需要做到的是將正確的全識別爲正確,錯誤的全識別爲錯誤
true_labels = Variable(torch.ones(n))
real_image_loss = Bce_loss(logits_real, true_labels) # 識別正確的爲正確
fake_image_loss = Bce_loss(logits_fake, 1 - true_labels) # 識別錯誤的爲錯誤
loss = real_image_loss + fake_image_loss
return loss
def generator_loss(logits_fake):
n = logits_fake.size()
# 生成器的作用是將所有“假”的向真的(1)靠攏
true_labels = Variable(torch.ones(n))
# 計算生成器損失
loss = Bce_loss(logits_fake, true_labels)
return loss
11.Adam優化器
def get_optimizer(model):
"""
爲模型構建並返回一個Adam優化器
learning rate 1e-3,
beta1=0.5, and beta2=0.999.
"""
# params(iterable):可用於迭代優化的參數或者定義參數組的dicts。
# lr (float, optional) :學習率(默認: 1e-3)
# betas (Tuple[float, float], optional):用於計算梯度的平均和平方的係數(默認: (0.9, 0.999))
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))
return optimizer
12.GAN函數
def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250,
batch_size=128, noise_size=96, num_epochs=10):
"""
訓練GAN
- D, G: 分別爲判別器和生成器
- D_solver, G_solver: D,G的優化器
- discriminator_loss, generator_loss: 計算D,G的損失
- show_every: 設置每show_every次顯示樣本
- batch_size: 每次訓練在訓練集中取batch_size個樣本訓練
- noise_size: 輸入進生成器的噪聲維度
- num_epochs: 訓練迭代次數
"""
iter_count = 0
for epoch in range(num_epochs):
for x, _ in loader_train:
if len(x) != batch_size:
continue
D_solver.zero_grad()
real_data = Variable(x)
logits_real = D(2 * (real_data - 0.5))
g_fake_seed = Variable(sample_noise(batch_size, noise_size))
fake_images = G(g_fake_seed).detach()
logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
d_total_error = discriminator_loss(logits_real, logits_fake)
d_total_error.backward()
D_solver.step()
G_solver.zero_grad()
g_fake_seed = Variable(sample_noise(batch_size, noise_size))
fake_images = G(g_fake_seed)
gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
g_error = generator_loss(gen_logits_fake)
g_error.backward()
G_solver.step()
print(iter_count)
if iter_count % show_every == 0:
print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error, g_error))
imgs_numpy = fake_images.data.cpu().numpy()
show_images(imgs_numpy[0:16])
plt.show()
print()
iter_count += 1
print("Completed!")
imgs_numpy = fake_images.data.cpu().numpy()
show_images(imgs_numpy[0:16])
plt.show()
print()
13.
# 創建判別器
D = discriminator()
# 創建生成器
G = generator()
# 創建D,G的優化器
D_solver = get_optimizer(D)
G_solver = get_optimizer(G)
# 運行GAN
run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss)
最終結果爲:
效果很差,後續我還需要對這個初步的GAN進行完善。