50 行 PyTorch 代碼搞定 GAN

最近,一篇非常火的報道,使用pytorch 加 50 行核心代碼模擬 GAN  對抗神經網絡,自己嘗試走了一遍,並對源碼提出自己的理解。原文鏈接如下 https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.nr3akui9z 

code 地址 https://github.com/devnag/pytorch-generative-adversarial-networks

首先配置pytorch, 這個直接去github 上的pytorch官網按照教程去配,鏈接地址 https://github.com/pytorch/pytorch

之後把code 下下來,直接運行就可以了。


數據的產生,正樣本也就是使用一個均值爲4,標準差爲1.25的高斯分佈,

torch.Tensor(np.random.normal(mu, sigma, (1, n)))
負樣本是隨機數

torch.rand(m, n) 

兩個網路的定義,一個是生成網絡,將其參數打印出來如下

Generator (

  (map1): Linear (1 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

輸入是(1L,1L)

另一個是判別網路,輸入是(200L,1L)

Discriminator (
  (map1): Linear (200 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

在初始化網絡的時,作者對數據進行preprocess預處理, 使用的方法如下

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)
分析其實現過程,是||data-mean||的二階範數,也就是(data-mean)^2,然後將這個值放到data後面,這樣本來data是(100L,1L)就變成(200L,1L)

自己很疑惑爲什麼對數據進行處理之後還要把原始數據也串起來,爲何不直接return diffs?

於是自己嘗試了返回data+diffs 的方式和直接返回diffs的方式,打印出了日誌,發現使用data+diffs 的方式 比 直接使用diffs的方式,G生成器產生的數據有很大不同,前者產生的假的數據更接近真實的數據,也就是前者的方式學習得到的生成器要比後者更加聰明,我們可以看一下訓練結果,我截取了一部分。

使用data+diffs

只使用diffs


自己給出解釋,如圖深度學習訓練圖片分類一樣,從圖片中提取feature, 一般使用feature的維度越高,網絡的表達能力就更強。這個也類似, 使用data+diffs維度爲200,而只用diffs的維度是100,毫無疑問前者可供學習的東西更多更豐富,所以前者的生成器產生的數據迷惑判別器的難度更大,對判別器的提高能力也就更大。






發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章