GAN的基本總結和小型demo

GAN的基本總結和小型demo

關於GANS(Generative Adversarial Networks)

屬於生成模型(generative models)

屬於無監督學習(unsupervised learning)

在不給定目的值的情況下,學習所給數據的底層結構。

目前可生成最清晰的圖像。

易於訓練(不需要統計推斷),只需要反向推斷就能夠獲得梯度。

由於訓練動態不穩定,難以優化。

基本不能做統計推斷。

屬於直接隱式密度模型,沒有明確定義概率分佈函數模型。

Generator和Discriminator

Discriminator

最大化被分類爲屬於真數據集的真數據輸入

最小化被分類爲屬於真數據集的假數據輸入

Generator

最大化被分類爲屬於真數據集的假數據輸入

這意味着用於此網絡的損耗/誤差函數(loss/error函數)要最大化

經過許多步的訓練,Generator和Discriminator都有足夠的能力,均不能再進行改進,此時Generator就能生成真實的合成數據,而Discriminator已經無法區分。

訓練GAN的基本步驟

1.採樣噪聲集和真實數據集,每個數據集具有大小m。

2.在這個數據上訓練鑑別器。

3.採樣具有大小m的不同噪聲子集。

4.根據這個數據訓練生成器。

5.從步驟1重複。

GAN的小型demo

1.導入相關庫

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as t
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dataset
import numpy as np
# 繪製圖像庫
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

2.設置plt屬性

plt.rcParams['figure.figsize'] = (10.0, 8.0)  # 設置大小
plt.rcParams['image.interpolation'] = 'nearest'  # 設置插值模式
plt.rcParams['image.cmap'] = 'gray'  # 設置顏色

3.圖片顯示

def show_images(images):
    images = np.reshape(images, [images.shape[0], -1])  # -1代表自動計算
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))      # np.ceil取整
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
​
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)
​
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')  # 去掉座標軸
        ax.set_xticklabels([])  # 設置x標記爲空
        ax.set_yticklabels([])  # 設置y標記爲空
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return


4.採樣函數

# 採樣函數爲自己定義的序列採樣(即按順序採樣)
class Sampler(sampler.Sampler):
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start
​
    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))
​
    def __len__(self):
        return self.num_samples

5.訓練集和測試集的設置

NUM_TRAIN = 60000   # 訓練集數量
NUM_VAL = 10000      # 測試集數量
​
NOISE_DIM = 96       # 噪聲維度
batch_size = 128     # 批尺寸
​
mnist_train = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_train = DataLoader(mnist_train, batch_size=batch_size, sampler=Sampler(NUM_TRAIN, 0))
# 從0位置開始採樣NUM_TRAIN個數
​
mnist_val = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_val = DataLoader(mnist_val, batch_size=batch_size, sampler=Sampler(NUM_VAL, NUM_TRAIN))
# 從NUM_TRAIN位置開始採樣NUM_VAL個數
​
imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
show_images(imgs)  # 顯示訓練集圖片

6.均勻噪聲函數

def sample_noise(batch_size, dim):
    """
    - 產生一個從-1 ~ 1的均勻噪聲函數,形狀爲 [batch_size, dim].
    參數:
    - batch_size: 整型 提供生成的batch_size
    - dim: 整型 提供生成維度
    """
    temp = torch.rand(batch_size, dim) + torch.rand(batch_size, dim)*(-1)
​
    return temp

7.平鋪函數

# 平鋪函數
​
​
class Flatten(nn.Module):
    def forward(self, x):
        n, c, h, w = x.size()  # 讀取爲n,c,h,w
        return x.view(n, -1)  # 每張圖片把c*h*w的值傳入單向量用於後期處理

8.判別器

# 判別器  判斷generator產生的圖像是否爲假,同時判斷正確的圖像是否爲真
​
​
def discriminator():
    model = nn.Sequential(
        Flatten(),
        nn.Linear(784, 256),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Linear(256, 1)
    )
    return model

9.生成器

# 生成器
​
​
def generator(noise_dim=NOISE_DIM):
    model = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )
    return model

10.損失函數

# GAN中指出的最大化最小化損失的算法
​
​
Bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake):
    loss = None
​
    # Batch size.
    n = logits_real.size()
​
    # 目標label,全部設置爲1意味着判別器需要做到的是將正確的全識別爲正確,錯誤的全識別爲錯誤
    true_labels = Variable(torch.ones(n))
​
    real_image_loss = Bce_loss(logits_real, true_labels)  # 識別正確的爲正確
    fake_image_loss = Bce_loss(logits_fake, 1 - true_labels)  # 識別錯誤的爲錯誤
​
    loss = real_image_loss + fake_image_loss
​
    return loss
​
​
def generator_loss(logits_fake):
    n = logits_fake.size()
​
    # 生成器的作用是將所有“假”的向真的(1)靠攏
    true_labels = Variable(torch.ones(n))
​
    # 計算生成器損失
    loss = Bce_loss(logits_fake, true_labels)
​
    return loss

11.Adam優化器

def get_optimizer(model):
    """
    爲模型構建並返回一個Adam優化器
    learning rate 1e-3,
    beta1=0.5, and beta2=0.999.
    """
    # params(iterable):可用於迭代優化的參數或者定義參數組的dicts。
    # lr (float, optional) :學習率(默認: 1e-3)
    # betas (Tuple[float, float], optional):用於計算梯度的平均和平方的係數(默認: (0.9, 0.999))
​
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))
    return optimizer

12.GAN函數

def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250,
              batch_size=128, noise_size=96, num_epochs=10):
    """
    訓練GAN
    - D, G: 分別爲判別器和生成器
    - D_solver, G_solver: D,G的優化器
    - discriminator_loss, generator_loss: 計算D,G的損失
    - show_every: 設置每show_every次顯示樣本
    - batch_size: 每次訓練在訓練集中取batch_size個樣本訓練
    - noise_size: 輸入進生成器的噪聲維度
    - num_epochs: 訓練迭代次數
    """
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in loader_train:
            if len(x) != batch_size:
                continue
​
            D_solver.zero_grad()
            real_data = Variable(x)
            logits_real = D(2 * (real_data - 0.5))
​
            g_fake_seed = Variable(sample_noise(batch_size, noise_size))
            fake_images = G(g_fake_seed).detach()
            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
​
            d_total_error = discriminator_loss(logits_real, logits_fake)
            d_total_error.backward()
            D_solver.step()
​
            G_solver.zero_grad()
            g_fake_seed = Variable(sample_noise(batch_size, noise_size))
            fake_images = G(g_fake_seed)
​
            gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            g_error = generator_loss(gen_logits_fake)
            g_error.backward()
            G_solver.step()
​
            print(iter_count)
​
            if iter_count % show_every == 0:
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error, g_error))
                imgs_numpy = fake_images.data.cpu().numpy()
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1
​
    print("Completed!")
    imgs_numpy = fake_images.data.cpu().numpy()
    show_images(imgs_numpy[0:16])
    plt.show()
    print()

13.

# 創建判別器
D = discriminator()
​
# 創建生成器
G = generator()
​
# 創建D,G的優化器
D_solver = get_optimizer(D)
G_solver = get_optimizer(G)
​
# 運行GAN
run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss)

最終結果爲:

 

效果很差,後續我還需要對這個初步的GAN進行完善。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章