50 行 PyTorch 代码搞定 GAN

最近,一篇非常火的报道,使用pytorch 加 50 行核心代码模拟 GAN  对抗神经网络,自己尝试走了一遍,并对源码提出自己的理解。原文链接如下 https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.nr3akui9z 

code 地址 https://github.com/devnag/pytorch-generative-adversarial-networks

首先配置pytorch, 这个直接去github 上的pytorch官网按照教程去配,链接地址 https://github.com/pytorch/pytorch

之后把code 下下来,直接运行就可以了。


数据的产生,正样本也就是使用一个均值为4,标准差为1.25的高斯分布,

torch.Tensor(np.random.normal(mu, sigma, (1, n)))
负样本是随机数

torch.rand(m, n) 

两个网路的定义,一个是生成网络,将其参数打印出来如下

Generator (

  (map1): Linear (1 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

输入是(1L,1L)

另一个是判别网路,输入是(200L,1L)

Discriminator (
  (map1): Linear (200 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

在初始化网络的时,作者对数据进行preprocess预处理, 使用的方法如下

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)
分析其实现过程,是||data-mean||的二阶范数,也就是(data-mean)^2,然后将这个值放到data后面,这样本来data是(100L,1L)就变成(200L,1L)

自己很疑惑为什么对数据进行处理之后还要把原始数据也串起来,为何不直接return diffs?

于是自己尝试了返回data+diffs 的方式和直接返回diffs的方式,打印出了日志,发现使用data+diffs 的方式 比 直接使用diffs的方式,G生成器产生的数据有很大不同,前者产生的假的数据更接近真实的数据,也就是前者的方式学习得到的生成器要比后者更加聪明,我们可以看一下训练结果,我截取了一部分。

使用data+diffs

只使用diffs


自己给出解释,如图深度学习训练图片分类一样,从图片中提取feature, 一般使用feature的维度越高,网络的表达能力就更强。这个也类似, 使用data+diffs维度为200,而只用diffs的维度是100,毫无疑问前者可供学习的东西更多更丰富,所以前者的生成器产生的数据迷惑判别器的难度更大,对判别器的提高能力也就更大。






發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章