GAN 入门(PyTorch1.0版)

GAN

每一种著名的代码都值得研究,觉得过于简单,也可能是因为不够了解又或者产生了错误的见解。并不是能用就可以结束对事物的研究发现问题也许会产生更好的结果。

Generator

从随机数组中产生图片,利用全连接产生有序数组,可以转换为图片
以下是代码

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

Discriminator

判别器,用于判断真假图片,训练过程中,真假图片以1:1输入。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

其他

损失函数使用了二分类交叉熵(binary cross entropy)
torch.nn.BCELoss()

网络搭建使用了torch.nn.Sequential()

训练过程中D&G一起训练

由于全连接网络层数较浅,激活函数选择相对随意

但是

激活函数的制定需要根据输出的性质进行选择,若最后一层的激活函数选择过于不合理网络可能不收敛,此时的效果比无激活函数更差

网络层数较深时,尽量不要选择sigmoid作为激活函数,易出现梯度消失

在训练D时,对生成的图片进行了反向传播阻断detach函数。

图像数据的产生与还原

在此代码中,Generator最后一层激活函数为tanh,因此生成的数组值域为[-1,1]
对于输入的图像则使用了归一化(可能翻译为规范化会好一些?Normalization):
归一化
对于归一化,正则化等操作,有很多博客的解释可能存在问题,如有代码建议直接查阅代码内置解释(使用python的help函数)
(Image-0.5)/0.5
在训练过程中二者值域相等,可以正常训练

还原时采用PyTorch内置函数
from torchvision.utils import save_image
对输入的数组进行了二次normalize,此处normalize无输入参数
应对于[-1,1]之间的数字还原至[0,1],然后再还原为[0,255]整形数字

输入与输出在数据转换过程中保持了一致,同时需要注意,二者的图片均非二值图片,有灰色的像素点存在。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章