最近在做科研上的項目,需要調各種GAN的模型,鑑於網上各種拿着標準數據集跑模型的流氓行爲,本人決定推出一種對各種數據集都適用的模型訓練教程。
話不多說,先上代碼,大家看着我的代碼,加上我的講解,相信所有人都能無痛調節模型的參數。
我用的是github上PyTorch-GAN的代碼,這個github實現了很多種類的GAN,並且寫出來的模型也不復雜,很適合小白。然後我調的是DCGAN
DCGAN- 模型代碼
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.init_size = opt.img_size // 4 self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2)) self.conv_blocks = nn.Sequential( nn.BatchNorm2d(128), nn.Upsample(scale_factor=2), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, opt.channels, 3, stride=1, padding=1), nn.Tanh() ) def forward(self, z): out = self.l1(z) out = out.view(out.shape[0], 128, self.init_size, self.init_size) img = self.conv_blocks(out) return img class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() def discriminator_block(in_filters, out_filters, bn=True): block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] if bn: block.append(nn.BatchNorm2d(out_filters, 0.8)) return block self.model = nn.Sequential( *discriminator_block(opt.channels, 16, bn=False), *discriminator_block(16, 32), *discriminator_block(32, 64), *discriminator_block(64, 128), ) # The height and width of downsampled image ds_size = opt.img_size // 2**4 self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1), nn.Sigmoid()) def forward(self, img): out = self.model(img) out = out.view(out.shape[0], -1) validity = self.adv_layer(out) return validity
上面的代碼是DCGAN
最核心的部分。是由兩個網絡組成:生成網絡和判別網絡。好了,接下來我一點點的解釋這些代碼。
Generator生成網絡代碼
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.init_size = opt.img_size // 4 self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2)) self.conv_blocks = nn.Sequential( nn.BatchNorm2d(128), nn.Upsample(scale_factor=2), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, opt.channels, 3, stride=1, padding=1), nn.Tanh() ) def forward(self, z): out = self.l1(z) out = out.view(out.shape[0], 128, self.init_size, self.init_size) img = self.conv_blocks(out) return img
看這個生成網絡的代碼,需要從def forward(self,z):
看起。這個是數據真正被處理的流程。
這個處理的順序是:
-
out = self.l1(z)
-
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
-
img = self.conv_blocks(out)
-
return img
第一個語句執行的是l1
函數,這個函數在上面的class
裏面定義好了:self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2))
這個語句的意思就是l1
函數進行的是Linear
變換。這個線性變換的兩個參數是變換前的維度,和變換之後的維度。博主建議大家一個學習使用這些函數的方法:如果你使用的是pycharm 就可以選中你想了解的函數,然後按下ctrl + B 就可以跳轉到該函數的定義處,一般在這個函數裏都會有如何使用的介紹,以及example,非常好用。
那你會問了:這個Linear
函數裏面使用的參數是 self.init_size = opt.img_size // 4
,爲什麼不是opt.img_size
呢,這個就是接下來需要說的一個上採樣
.
第二個語句執行的是view()
函數,這個函數很簡單,是一個維度變換函數,我們可以看到out
數據變成了四維數據,第一個是batch_size(通過整個的代碼,你就可以明白了)
,第二個是channel
,第三,四是單張圖片的長寬。
第三個語句執行的是self.conv_blocks(out)
函數,這個函數我們往上看,可以看到:
self.conv_blocks = nn.Sequential( nn.BatchNorm2d(128), nn.Upsample(scale_factor=2), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, opt.channels, 3, stride=1, padding=1), nn.Tanh() )
nn.sequential{}
是一個組成模型的殼子,用來容納不同的操作。
我們大體可以看到這個殼子裏面是由BatchNorm2d
,Upsample
,Conv2d
,LeakyReLU
組成。
第一個是歸一化函數對數據的形狀沒影響主要就是改變數據的量綱。
第二個函數是上採樣函數,這個函數會將單張圖片的尺寸進行放大(這就是爲什麼class最先開始將圖片的長寬除了4,是因爲殼子裏面存在兩個2倍的上採樣函數
)。
第三個函數是二維卷積函數,各個參數分別是輸入數據的channel
,輸出數據的channel
,剩下的三個參數是卷積的三個參數:卷積步長,卷積核大小,padding的大小。這個二維卷積函數會對channel
的大小有影響,同時還會對單張圖片的大小有影響。卷積的計算公式$H_{out} = (H_{in}-1)* S-2*P +K $
第四個函數是一個帶有傾斜角度的激活函數,它是由ReLu
函數改造而來的。
好了生成網絡就講完了。我們再來看判別網絡:
Discriminator判別網絡代碼
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() def discriminator_block(in_filters, out_filters, bn=True): block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] if bn: block.append(nn.BatchNorm2d(out_filters, 0.8)) return block self.model = nn.Sequential( *discriminator_block(opt.channels, 16, bn=False), *discriminator_block(16, 32), *discriminator_block(32, 64), *discriminator_block(64, 128), ) # The height and width of downsampled image ds_size = opt.img_size // 2**4 self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1), nn.Sigmoid()) def forward(self, img): out = self.model(img) out = out.view(out.shape[0], -1) validity = self.adv_layer(out) return validity
同樣地用生成網絡看代碼的順序,看這段代碼。數據處理的流程分四個步驟:
out = self.model(img)
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
第一個語句執行的是model
函數。好,我們來看model
函數在class
裏面是如何定義的。
def discriminator_block(in_filters, out_filters, bn=True): block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] if bn: block.append(nn.BatchNorm2d(out_filters, 0.8)) return block self.model = nn.Sequential( *discriminator_block(opt.channels, 16, bn=False), *discriminator_block(16, 32), *discriminator_block(32, 64), *discriminator_block(64, 128), )
model
函數是由四個discriminator_block
函數組成。然後我們再看discriminator_block
函數的定義:
def discriminator_block(in_filters, out_filters, bn=True): block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] if bn: block.append(nn.BatchNorm2d(out_filters, 0.8)) return block
這個模塊是由四部分組成:conv2d
,leakyRelu
,Dropout
,BatchNorm2d
。
第一個語句:conv2d
函數,是用來卷積的
第二個語句是:'leakyRelu’函數,用來做激活函數的
第三個語句:Dropout
函數用來將部分神經元失活,進而防止過擬合
第四個語句:其實是一個判斷語句,如果bn
這個參數爲True
,那麼就需要在block
塊裏面添加上BatchNorm
的歸一化函數。
第二個語句執行的是view(out.shape[0],-1)
,這個語句是將處理之後的數據維度變成batch * N
的維度形式,然後再放到最後一個語句裏面執行。
第三個語句:self.adv_layer
函數,這個函數是由:self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1),nn.Sigmoid())
就是先進行線性變換,再進行激活函數激活。其中第一個參數128*ds_size**2
中128是指model
模塊中最後一個判別模塊的最後一個參數決定的,ds_size
是由model
模塊對單張圖片的卷積效果決定的,而2次方的原因是,整個模型是選取的長寬一致的圖片。
避坑指南
有時候我們在腦子裏面想的很多設想其實和實際的情況不太一樣,比如第一個坑。
- 在MATLAB和python裏面經常遇到這樣的情況,明明讀取進去的圖片是正常的,出來就是一個白板。這個問題在這個博客《python中opencv imshow函數顯示一片白色原因》中解決的很好,會出現這個問題的原因是每個像素的數據類型出現了問題。
參考文獻
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks