# How to Build a GAN with PyTorch?

{"type":"doc","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"生成對抗網絡(Generative Adversarial Network,GAN)由 Goodfellow 等人在 2014 年提出,它徹底改變了計算機視覺中的圖像生成領域:沒有人能夠相信這些令人驚歎而生動的圖像實際上是純粹由機器生成的。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"事實上,人們曾經認爲生成的任務是不可能的,並且被 GAN 的力量所震驚,因爲傳統上,根本沒有任何事實可以比較我們生成的圖像。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"本文介紹了創建 GAN 背後的簡單直覺,然後介紹了通過 PyTorch 實現的卷積 GAN 及其訓練過程。"}]},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","text":"GAN 背後的直覺"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"不同於傳統分類方法,我們的網絡預測可以直接與事實的正確答案相比較,而生成圖像的“正確性”是很難定義和衡量的。Goodfellow 等人在他們的原創論文《生成對抗網絡》("},{"type":"text","marks":[{"type":"italic"}],"text":"Generative Adversarial Network"},{"type":"text","text":")中提出了一個有趣的想法:使用經過訓練的分類器來區分生成的圖像和實際圖像。如果存在這樣的分類器,我們可以創建並訓練一個生成器網絡,直到它輸出的圖像能完全騙過分類器。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.geekbang.org\/resource\/image\/72\/8a\/722e329b3b82210b6b738b9c3bbd7d8a.jpg","alt":null,"title":null,"style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":null,"fromPaste":true,"pastePass":true}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":"center","origin":null},"content":[{"type":"text","marks":[{"type":"size","attrs":{"size":10}}],"text":" GAN管道"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"GAN 是這一過程的產物:它包含一個根據給定的數據集生成圖像的生成器,以及一個區分圖像是真實的還是生成的判別器(分類器)。GAN 的詳細管道見圖 1。"}]},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"損失函數"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"對生成器和判別器進行優化都很困難,因爲正如你所想象的那樣,這兩個網絡的目標完全相反:生成器希望儘可能地創造出真實的東西,但判別器希望區分生成的材料。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"爲了說明這一點,我們讓 D(x) 是判別器的輸出,也就是 x 是真實圖像的概率,而 G(z) 是我們的生成器的輸出。判別器類似於一個二元分類器,因此判別器的目標是使函數最大化:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"本質上是二元交叉熵損失,沒有開頭的負號。另一方面,生成器的目標是使判別器做出正確判斷的機會最小化,因此它的目標是最小化函數。所以,最終的損失函數將是兩個分類器之間的一個極小極大博弈(minimax game),具體如下:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.geekbang.org\/resource\/image\/4f\/7a\/4f3489ecc427fa4f5001f0074c8cd47a.jpg","alt":null,"title":null,"style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":null,"fromPaste":true,"pastePass":true}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"從理論上講,這將收斂到判別器,預測所有事件的概率爲 
0.5。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"但在實踐中,極小極大博弈往往會導致網絡無法收斂,因此仔細調整訓練過程非常重要。像學習率這樣的超參數對於訓練 GAN 時顯然更爲重要:一個微小的變化會導致 GAN 產生一個輸出,而與輸入噪聲無關。"}]},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","text":"運算環境"}]},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"庫"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"我們通過 PyTorch 庫(包括 torchvision)來構建整個程序。GAN 的生成結果的可視化是通過 Matplotlib 庫繪製的。下面的代碼導入了所有的庫:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"blockquote","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"link","attrs":{"href":"https:\/\/gist.github.com\/ttchengab\/1ae059ec0b37238c64ddc6e308e6f887#file-importgan-py","title":"","type":null},"content":[{"type":"text","text":"importGAN.py"}]}]}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"python"},"content":[{"type":"text","text":"\"\"\"\nImport necessary libraries to create a generative adversarial network\nThe code is mainly developed using the PyTorch library\n\"\"\"\nimport time\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\nfrom torchvision import datasets\nfrom torchvision.transforms import transforms\nfrom model import discriminator, generator\nimport numpy as np\nimport matplotlib.pyplot as plt\n"}]},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"數據集"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"在 GAN 訓練中,數據集是一個重要方面。圖像的非結構化性質意味着任何給定的類別(如狗、貓或手寫的數字)都可以有一個可能的數據分佈,而這種分佈最終是 GAN 生成內容的基礎。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"爲了演示,本文將使用最簡單的 "},{"type":"link","attrs":{"href":"https:\/\/gas.graviti.com\/dataset\/hellodataset\/MNIST?utm_medium=0708Taying_2","title":"","type":null},"content":[{"type":"text","text":"MNIST 數據集"}]},{"type":"text","text":",其中包含 60000 張從 0 到 9 的手寫數字圖像。事實上,像 MNIST 這樣的非結構化數據集可以在 "},{"type":"link","attrs":{"href":"https:\/\/graviti.com\/?utm_medium=0708Taying_2","title":"","type":null},"content":[{"type":"text","text":"Graviti"}]},{"type":"text","text":" 上找到。這是一家年輕的創業公司,他們希望通過非結構化數據集爲社區提供幫助,在他們的"},{"type":"link","attrs":{"href":"https:\/\/gas.graviti.com\/open-datasets\/?utm_medium=0708Taying_2","title":"","type":null},"content":[{"type":"text","text":"平臺"}]},{"type":"text","text":"上有一些最好的公共非結構化數據集,包括 MNIST。"}]},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"硬件要求"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"最好的方法是用 GPU 訓練神經網絡,它可以顯著地提高訓練速度。但是,如果只有 CPU 
### Hardware Requirements

Neural networks are best trained on a GPU, which speeds up training significantly, but you can still test the program if only a CPU is available. To let the program determine the hardware on its own, you can use the following:

[torchDevice.py](https://gist.github.com/ttchengab/dde281557818f0d701ececaeeb7f1b0a#file-torchdevice-py)

```python
"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

## Implementation

### Network Architecture

Because digits are so simple, both architectures, the discriminator and the generator, are built from fully connected layers. Note that in some cases a fully connected GAN can even be slightly easier to get to converge than a DCGAN.

Below is the PyTorch implementation of the two architectures:

[GANArchitecture.py](https://gist.github.com/ttchengab/2d208d73cc4f191b1641276dd64d110c#file-ganarchitecture-py)

```python
"""
Network Architectures
The following are the discriminator and generator architectures
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, 1, 28, 28)
        return nn.Tanh()(x)
```

### Training

When training the GAN, we optimize the discriminator and improve the generator at the same time, so each iteration optimizes two conflicting losses simultaneously. What we feed into the generator is random noise, and the generator is expected to produce different images from the slight differences in the given noise:

[trainGAN.py](https://gist.github.com/ttchengab/0a8b5820043c6352f5cbcb7764f2eb62#file-traingan-py)
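The loop below also references `epochs`, `D`, `G`, `loss`, `D_optimizer` and `G_optimizer`, none of which appear in the excerpt. A minimal setup sketch, building on the imports and the `device` defined above; the binary cross-entropy loss, the Adam optimizers, and the learning rate of 2e-4 are assumptions rather than values stated in the article:

```python
# Hypothetical setup (not shown in the excerpt): instantiate the two networks,
# the binary cross-entropy loss, and one Adam optimizer per network.
epochs = 100
lr = 2e-4

D = discriminator().to(device)
G = generator().to(device)

loss = nn.BCELoss()
D_optimizer = optim.Adam(D.parameters(), lr=lr)
G_optimizer = optim.Adam(G.parameters(), lr=lr)
```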
```python
"""
Network training procedure
Every step both the loss for the discriminator and the generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For generator, goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch + 1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')
```

## Results

After 100 epochs, we can plot the generated images and see the digits produced from random noise:

![Digits generated by the GAN](https://static001.geekbang.org/resource/image/20/6f/204de0dfa951e742f792685e2f0d0b6f.jpg)

*Figure 2: Results generated by the GAN*

As shown above, the generated results really do look quite close to the real thing. Given how simple the network is, the results look very promising!
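The article imports Matplotlib but does not show the plotting step. A minimal sketch, assuming we reload one of the generator checkpoints saved by the loop above (`Generator_epoch_99.pth` follows its naming pattern) and draw a 4x4 grid of samples; it builds on the `device` and imports defined earlier:

```python
# Load a saved generator and visualize a grid of generated digits.
# (Newer PyTorch versions may require torch.load(..., weights_only=False)
# because the whole module object was pickled, not just its state_dict.)
G = torch.load('Generator_epoch_99.pth')
G = G.to(device)
G.eval()

with torch.no_grad():
    noise = ((torch.rand(16, 128) - 0.5) / 0.5).to(device)
    fakes = G(noise).cpu()  # shape (16, 1, 28, 28), values in [-1, 1]

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for img, ax in zip(fakes, axes.flatten()):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
plt.show()
```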
## Beyond Pure Content Creation

The invention of GANs was unlike any previous work in computer vision, and the many applications that followed astonished academia with what deep networks are capable of. A few of these remarkable works are introduced below.

### CycleGAN

CycleGAN, by Zhu et al., introduced the idea of translating an image from domain X to domain Y without paired samples. Horses are turned into zebras and summer sunshine into a snowstorm; CycleGAN's results are surprisingly accurate.

![CycleGAN results](https://static001.geekbang.org/resource/image/fb/f4/fb49b97a149f0f6753d4cefa9c8763f4.jpg)

*Figure 3: Results generated by CycleGAN (Zhu et al.)*

### GauGAN

Nvidia harnessed the power of GANs to turn simple drawings into elegant, photorealistic images based on the semantics of the brush strokes. Although training it is computationally expensive, it has opened up a whole new area of research and applications.

![GauGAN results](https://static001.geekbang.org/resource/image/2f/0e/2fdd8yybafcb55ae2266ea392969540e.jpg)

*Figure 4: Results generated by GauGAN. Left: the original drawing; right: the generated result.*

### AdvGAN

GANs have also been extended to cleaning up adversarial images, turning them into clean samples that no longer fool a classifier. More on adversarial attacks and defenses can be found [here](https://towardsdatascience.com/adversarial-attack-and-defense-on-neural-networks-in-pytorch-82b5bcd9171).

## Conclusion

And there you have it! I hope this article has given you an overview of how to build a GAN. The complete implementation can be found in the GitHub repository below:

https://github.com/ttchengab/MnistGAN

**About the author:**

Ta-ying Cheng is from Hong Kong, China, an incoming DPhil student at the University of Oxford whose interests include 3D vision and deep learning.
**Original article:**

https://towardsdatascience.com/building-a-gan-with-pytorch-237b4b07ca9a