用pytorch实现对抗生成网络

最近在学习深度学习编程,采用的深度学习框架是pytorch,看的书主要是陈云编著的《深度学习框架PyTorch入门与实践》、廖星宇编著的《深度学习入门之PyTorch》、肖志清的《神经网络与PyTorch实践》,都是入门的学习材料,适合初学者。

通过近1个多月的学习,基本算是入门了,后面将深度学习与实践。这里分享一个《神经网络与PyTorch实践》中对抗生成网络的例子。它是用对抗生成网络的方法,训练CIFAR-10的数据集,训练模型。

生成网络gnet将大小为(64,11)的潜在张量转化为大小为(3,32,32)的假数据;鉴别网络dnet将大小为(3,32,32)的数据转化为大小为
(1,1,1)的对数赔率张量。下面是整个模型的python代码,包括(1)数据加载,(2)模型搭建,(3)模型训练与模型测试。


import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10,CIFAR100
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchviz import make_dot

dataset = CIFAR100(root='./data',
                  download=True,
                  transform= transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

#check the data
#for batch_idx, data in enumerate(dataloader):
#    real_images, _ = data
#    print('real_images size = {}'.format(real_images.size()))
#    batch_size = real_images.size(0)
#    print('#{} has {} images.'.format(batch_idx, batch_size))
#    if batch_idx %100 ==0:
#        path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
#        save_image(real_images, path, normalize=True)


#construct the generator and discrimiter network
latent_size=64  #潜在大小
n_channel=3   #输出通道数
n_g_feature=64   #生成网络隐藏层大小
#construct the generator 
gnet= nn.Sequential(
    #输入大小 == (64, 1, 1)
    nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False),
    nn.BatchNorm2d(4*n_g_feature),
    nn.ReLU(),
    #大小 = (256,4,4)
    nn.ConvTranspose2d(4*n_g_feature, 2 * n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(2*n_g_feature),
    nn.ReLU(),
    #大小 = (128, 8,8)
    nn.ConvTranspose2d(2*n_g_feature, n_g_feature, kernel_size=4, stride=2, padding=1, bias= False),
    nn.BatchNorm2d(n_g_feature),
    nn.ReLU(),
    #大小 = (64,16,16)
    nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4, stride=2, padding=1),
    nn.Sigmoid(),
    #图片大小 = (3, 32, 32)
)

#define the instance of GeneratorNet
print(gnet)
if torch.cuda.is_available():
    gnet.to(torch.device('cuda:0'))

#construct the discrimator
n_d_feature = 64  #鉴别网络隐藏层大小
dnet = nn.Sequential(
    #图片大小 = (3,32,32)
    nn.Conv2d(n_channel, n_d_feature, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    #大小 = (63,16,16)
    nn.Conv2d(n_d_feature, 2*n_d_feature, kernel_size=4, stride=2, padding=1, bias= False),
    nn.BatchNorm2d(2*n_d_feature),
    nn.LeakyReLU(0.2),
    #大小 = (128, 8,8)
    nn.Conv2d(2*n_d_feature, 4*n_d_feature, kernel_size=4, stride=2, padding=1, bias= False),
    nn.BatchNorm2d(4*n_d_feature),
    nn.LeakyReLU(0.2),
    #大小 = (256,4,4)
    nn.Conv2d(4*n_d_feature, 1, kernel_size=4),
    #对数赔率张量大小=(1,1,1)    
    #nn.Sigmoid()
)
print(dnet)
if torch.cuda.is_available():
    dnet.to(torch.device('cuda:0'))

#initialization for gnet and dnet
def weights_init(m):
    if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
        init.xavier_normal_(m.weight)
    elif type(m) == nn.BatchNorm2d:
        init.normal_(m.weight, 1.0, 0.02)
        init.constant_(m.bias, 0)

gnet.apply(weights_init)
dnet.apply(weights_init)

#网络的训练和使用
#要构造一个损失函数并对它进行优化
#定义损失
criterion = nn.BCEWithLogitsLoss()
#定义优化器
goptimizer = torch.optim.Adam(gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(), lr=0.0002, betas=(0.5, 0.999))

#用于测试的噪声,用来查看相同的潜在张量在训练过程中生成图片的变换
batch_size=64
fixed_noises = torch.randn(batch_size, latent_size, 1,1)

#save the net to file for check
y=gnet(fixed_noises)
vise_graph = make_dot(y, params=dict(gnet.named_parameters()))
vise_graph.view(filename='gnet')


y=dnet(y)
vise_graph = make_dot(y)
vise_graph.view(filename='dnet')


#训练过程
epoch_num=10
for epoch in range(epoch_num):
    for batch_idx, data in enumerate(dataloader):
        #载入本批次数据
        real_images,_ = data
        batch_size = real_images.size(0)
        
        #训练鉴别网络
        labels = torch.ones(batch_size) #设置真实数据对应标签为1
        preds = dnet(real_images)  #对真实数据进行判别
        outputs = preds.reshape(-1)
        dloss_real = criterion(outputs, labels)  #真实数据的鉴别损失
        dmean_real = outputs.sigmoid().mean()  #计算鉴别器将多少比例的真实数据判定为真,仅用于输出显示
        
        noises = torch.randn(batch_size, latent_size, 1,1)  #潜在噪声
        fake_images = gnet(noises)  #生成假数据
        labels = torch.zeros(batch_size)  #假数据对应标签为0
        fake = fake_images.detach()  #是的梯度的计算不回溯到生成网络,可用于加快训练速度。删去此步,结果不变
        preds = dnet(fake)
        outputs = preds.view(-1)
        dloss_fake = criterion(outputs, labels)  #假数据的鉴别损失
        dmean_fake = outputs.sigmoid().mean()  #计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
        
        dloss = dloss_real+dloss_fake
        dnet.zero_grad()
        dloss.backward()
        doptimizer.step()
        
        #训练生成网络
        labels = torch.ones(batch_size) #生成网络希望所有生成的数据都是被认为时真的
        preds = dnet(fake_images) #让假数据通过假别网络
        outputs = preds.view(-1)
        gloss = criterion(outputs, labels)  #从真数据看到的损失
        gmean_fake = outputs.sigmoid().mean() #计算鉴别器将多少比例的假数据判断为真,仅用于输出显示
        gnet.zero_grad()
        gloss.backward()
        goptimizer.step()
        
        #输出本步训练结果
        print('[{}/{}]'.format(epoch, epoch_num)+
              '[{}/{}]'.format(batch_idx, len(dataloader))+
              '鉴别网络损失:{:g} 生成网络损失:{:g}'.format(dloss, gloss)+
              '真实数据判真比例:{:g} 假数据判真比例:{:g}/{:g}'.format(dmean_real, dmean_fake, gmean_fake))
        if batch_idx %100 == 0:
            fake = gnet(fixed_noises)  #由固定潜在征粮生成假数据
            save_image(fake, './data/images_epoch{:02d}_batch{:03d}.png'.format(epoch, batch_idx))  #保存假数据
        
        
#保存训练的网络
torch.save(gnet, 'gnet.pkl')
torch.save(dnet, 'dnet.pkl')

结果如下

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章