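I came across a GitHub project for generative adversarial networks (GANs) that looks decent, so I'm recommending it here.
Link: https://github.com/kwotsin/mimicry
Docs: https://mimicry.readthedocs.io/en/latest/guides/introduction.html
A fairly good reference: https://www.cnblogs.com/wanghui-garcia/p/10785579.html
The project is published as a Python package, so it can be installed directly with pip: pip install torch-mimicry
Once it is installed, the code below is enough to start training; it could hardly be easier. By default it trains on CIFAR-10.
The code comes from the official site and looks simple. I'm posting it here for now, though later updates to the project may change it.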
train.py
import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=0)
# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()
# Evaluate FID
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)
# Evaluate KID
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_subsets=50,
    subset_size=1000,
    evaluate_step=100000,
    device=device)
# Evaluate Inception Score
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/example',
    netG=netG,
    num_samples=50000,
    evaluate_step=100000,
    device=device)
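There is a bug here: around training step 50, an error about a missing path is raised. The missing directory is either errD or errG (the one underlined in red); I forget which. It is best to create it manually beforehand.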
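One way to apply that workaround is to create the directories before calling trainer.train(). This is only a sketch: the names errD and errG come from the error described above, but placing them directly under log_dir is my assumption, and ensure_log_dirs is a hypothetical helper rather than part of mimicry. Use whatever path the actual error message reports.

```python
import os


def ensure_log_dirs(log_dir, names=("errD", "errG")):
    """Create log_dir/<name> for each name and return the resulting paths.

    ASSUMPTION: the missing directories sit directly under log_dir.
    Adjust the names or location to match the path in the real error.
    """
    created = []
    for name in names:
        path = os.path.join(log_dir, name)
        # exist_ok=True makes the call safe to repeat on later runs.
        os.makedirs(path, exist_ok=True)
        created.append(path)
    return created


if __name__ == "__main__":
    for p in ensure_log_dirs("./log/example"):
        print("ready:", p)
```

Calling it once before training is enough; because of exist_ok=True, it does nothing harmful if the directories already exist.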
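Next, following the exercise from the third link above (the cnblogs reference), I tried training on the cartoon avatars it provides. Very little of the training script needs to change; only the data loading part does. The modified code is as follows: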
import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
import torchvision as tv
# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
image_size = 32  # changed from 96 to match the 32x32 output of SNGANGenerator32
data_path = "./own_data"  # Note: this directory contains a subfolder named face (./own_data/face), which holds all the images
batch_size = 256
num_workers = 0
transforms = tv.transforms.Compose([
    tv.transforms.Resize(image_size),
    tv.transforms.CenterCrop(image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = tv.datasets.ImageFolder(data_path, transform=transforms)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True)
# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()
# Evaluate FID
# NOTE: dataset_name='cifar10' compares against CIFAR-10, not the custom images
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)
# Evaluate KID (same caveat about dataset_name='cifar10' applies)
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_subsets=50,
    subset_size=1000,
    evaluate_step=100000,
    device=device)
# Evaluate Inception Score
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/example',
    netG=netG,
    num_samples=50000,
    evaluate_step=100000,
    device=device)
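Since the script uses the 32x32 variants of the generator and discriminator, it may be worth checking that the real images are resized to the same resolution the generator produces. The sketch below just does the arithmetic. It assumes that SNGANGenerator32 starts from a 4x4 feature map and doubles the resolution in each of three upsampling blocks; both numbers are my assumption rather than something read from the library, and the helper names are hypothetical.

```python
def generator_output_size(bottom_width=4, num_upsample_blocks=3):
    """Spatial size produced by a generator that doubles the resolution per block.

    ASSUMPTION: defaults (4x4 start, three 2x upsampling blocks) are a guess
    at how a 32x32 SNGAN generator is laid out.
    """
    return bottom_width * 2 ** num_upsample_blocks


def check_resolution(real_size, bottom_width=4, num_upsample_blocks=3):
    """Return (matches, fake_size) for real images resized to real_size."""
    fake_size = generator_output_size(bottom_width, num_upsample_blocks)
    return fake_size == real_size, fake_size


if __name__ == "__main__":
    for real_size in (32, 96):
        ok, fake_size = check_resolution(real_size)
        status = "match" if ok else "MISMATCH"
        print(f"real={real_size} fake={fake_size}: {status}")
```

If the two sizes differ, the discriminator could probably tell real from fake by resolution alone, which would be one possible reason for training that makes no visible progress.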
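I suspect the results won't be very good. To be continued...
Problems showed up during training: some parameters never changed, and others only fluctuated within a very small range. Help wanted; I'd appreciate any guidance. My training data contains two categories: real faces and cartoon faces.
Data directory layout: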
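torchvision's ImageFolder treats every subdirectory of the root as a separate class, so with two subfolders under ./own_data both kinds of images end up in the same training set. A quick way to see what will be picked up is to count the images per subfolder. This is a standard-library sketch: count_images_per_class is a hypothetical helper, and the extension list is my own choice rather than the exact set ImageFolder accepts.

```python
import os

# ASSUMPTION: a small set of common image extensions, chosen for illustration.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def count_images_per_class(root):
    """Map each immediate subdirectory of root to its number of image files."""
    counts = {}
    for entry in sorted(os.listdir(root)):
        class_dir = os.path.join(root, entry)
        if not os.path.isdir(class_dir):
            continue  # files directly under root do not form a class
        total = 0
        # Walk recursively so images in nested folders are counted too.
        for _dirpath, _dirnames, filenames in os.walk(class_dir):
            total += sum(
                1 for name in filenames
                if name.lower().endswith(IMAGE_EXTENSIONS)
            )
        counts[entry] = total
    return counts


if __name__ == "__main__":
    root = "./own_data"
    if os.path.isdir(root):
        for class_name, n in count_images_per_class(root).items():
            print(f"{class_name}: {n} images")
```

If the goal is to generate only cartoon faces, it is probably simpler to keep just that one subfolder under the data root, since an unconditional GAN would otherwise try to model both kinds of face at once.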