入坑生成對抗網絡(GAN)

看了以後感覺還行的關於生成對抗網絡的一個GitHub項目,安利給大家
鏈接:https://github.com/kwotsin/mimicry
文檔:https://mimicry.readthedocs.io/en/latest/guides/introduction.html
比較不錯的參考鏈接:https://www.cnblogs.com/wanghui-garcia/p/10785579.html
這個項目時以python包的形式發佈了的,直接可以用pip安裝pip install torch-mimicry
安裝完了就可以用下面的代碼開始訓練了,簡直不要太容易。默認是訓練cifar10數據的
代碼官網有,感覺很簡單,先放一下,後面更新可能會改
train.py

import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan

# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=0)

# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()

# Evaluate fid
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)

# Evaluate kid
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_subsets=50,
    subset_size=1000,
    evaluate_step=100000,        
    device=device)

# Evaluate inception score
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/example',
    netG=netG,
    num_samples=50000,
    evaluate_step=100000,        
    device=device)

這裏會有一個bug,就是訓練到50次的時候回提示一個缺少路徑的bug,缺少紅線標的errD還是errG我忘了,最好還是手動新建一下吧
錯誤說明

接下來根據以上第三個鏈接給出的練習,我試着訓練鏈接裏提供的卡通頭像,訓練腳本需要改的地方其實不多,只需要改數據讀入的地方就行了,更改後的代碼如下:

import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
import torchvision as tv

# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
image_size = 96
data_path = "./own_data"  #這裏注意一下,裏面放的是一個文件夾face,即./own_data/face,face裏面就是所有圖像了
batch_size = 256
num_workers = 0
transforms = tv.transforms.Compose([
        tv.transforms.Resize(image_size),
        tv.transforms.CenterCrop(image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

dataset = tv.datasets.ImageFolder(data_path, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=num_workers,
                                     drop_last=True
                                     )


# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()

# Evaluate fid
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)

# Evaluate kid
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_subsets=50,
    subset_size=1000,
    evaluate_step=100000,        
    device=device)

# Evaluate inception score
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/example',
    netG=netG,
    num_samples=50000,
    evaluate_step=100000,        
    device=device)

感覺效果不會很好的樣子,未完待續。。。

中間訓練過程出現了問題,有參數一直不變,有參數就一直在很小範圍內波動,求助啊,希望有人指導一下啊,我訓練放了兩個類別,一個是人臉,一個是卡通臉
訓練過程
數據存放目錄
數據結構

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章