去年很早的時候,曾經寫過一個AE(自編碼器)的實現,不過寫得比較拖沓冗長,不夠成熟。今天又看到了,就重新寫了一個。
一.代碼
1.全代碼名稱展示
2.主程序
(一).訓練階段
(1).dataset.py
import torch
import os
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
class GetData(Dataset):
    """Paired image dataset for the denoising autoencoder.

    Item ``i`` is ``(imgdata0, imgdata1)``: the i-th file of ``path0`` and
    the i-th file of ``path1`` (after sorting by name), each converted to a
    tensor by ``transforms.ToTensor``.  Pairing is purely positional, so the
    two directories are expected to hold corresponding files whose names
    sort into the same order.

    Args:
        path0: directory of target images (the clean originals in train.py).
        path1: directory of input images (the occluded versions in train.py).
    """

    def __init__(self, path0, path1):
        super(GetData, self).__init__()
        self.path0 = path0
        self.path1 = path1
        # os.listdir returns entries in arbitrary, filesystem-dependent order,
        # so two unsorted listings are not guaranteed to line up by index.
        # Sorting makes the clean/damaged pairing deterministic.
        self.name0_list = self._sorted_names(self.path0)
        self.name1_list = self._sorted_names(self.path1)
        self.img2data = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _sorted_names(path):
        """Return the entries of directory ``path`` in sorted order."""
        return sorted(os.listdir(path))

    def __len__(self):
        # Only positions present in both directories form a valid pair;
        # indexing past the shorter list would raise IndexError.
        return min(len(self.name0_list), len(self.name1_list))

    def __getitem__(self, index):
        """Return the ``(target, input)`` tensor pair at position ``index``."""
        self.name0 = self.name0_list[index]
        self.name1 = self.name1_list[index]
        imgdata0 = self._load(os.path.join(self.path0, self.name0))
        imgdata1 = self._load(os.path.join(self.path1, self.name1))
        return imgdata0, imgdata1

    def _load(self, file_path):
        """Open one image, force 3-channel RGB, and convert it to a tensor."""
        # ``with`` closes the file handle Image.open keeps open; convert('RGB')
        # guarantees the 3 channels the encoder's first convolution expects
        # (grayscale or RGBA files would otherwise yield 1 or 4 channels).
        with Image.open(file_path) as img:
            return self.img2data(img.convert('RGB'))
(2).net.py
import torch
import torch.nn as nn
#卷積
class ResConv2d(torch.nn.Module):
    """Residual bottleneck followed by a stride-2 downsampling convolution.

    ``forward`` computes ``down_net(x + sub_net(x))``.  ``sub_net`` keeps the
    channel count and spatial size (1x1 -> 3x3 -> 1x1 convolutions, each
    followed by BatchNorm2d + PReLU).  ``down_net`` is a 4x4 / stride 2 /
    padding 1 convolution mapping ``in_channels`` to ``out_channels``, which
    halves H and W (rounding down), followed by BatchNorm2d + PReLU.
    """

    def __init__(self, in_channels, out_channels):
        super(ResConv2d, self).__init__()
        # (kernel_size, padding) of the three shape-preserving convolutions,
        # built in the same order as before so parameter names and
        # initialisation order are unchanged.
        bottleneck = []
        for kernel, pad in ((1, 0), (3, 1), (1, 0)):
            bottleneck += [
                torch.nn.Conv2d(in_channels, in_channels, kernel, 1, pad),
                torch.nn.BatchNorm2d(in_channels),
                torch.nn.PReLU(),
            ]
        self.sub_net = torch.nn.Sequential(*bottleneck)
        self.down_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 4, 2, 1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.PReLU(),
        )

    def forward(self, x):
        residual = x + self.sub_net(x)  # skip connection around the bottleneck
        return self.down_net(residual)
#反捲積
class ResConvTranspose2d(torch.nn.Module):
    """Residual bottleneck followed by a stride-2 upsampling transposed conv.

    ``forward`` computes ``up_net(x + sub_net(x))``.  ``sub_net`` keeps the
    channel count and spatial size (1x1 -> 3x3 -> 1x1 transposed convolutions,
    each followed by BatchNorm2d + PReLU).  ``up_net`` is a 4x4 / stride 2 /
    padding 1 transposed convolution mapping ``in_channels`` to
    ``out_channels``, which doubles H and W, followed by BatchNorm2d + PReLU.
    """

    def __init__(self, in_channels, out_channels):
        super(ResConvTranspose2d, self).__init__()
        stage = self._stage
        self.sub_net = torch.nn.Sequential(
            *stage(in_channels, in_channels, 1, 1, 0),
            *stage(in_channels, in_channels, 3, 1, 1),
            *stage(in_channels, in_channels, 1, 1, 0),
        )
        self.up_net = torch.nn.Sequential(
            *stage(in_channels, out_channels, 4, 2, 1))

    @staticmethod
    def _stage(c_in, c_out, kernel, stride, padding):
        """Return ``[ConvTranspose2d, BatchNorm2d, PReLU]`` as a flat list."""
        return [
            torch.nn.ConvTranspose2d(c_in, c_out, kernel, stride, padding),
            torch.nn.BatchNorm2d(c_out),
            torch.nn.PReLU(),
        ]

    def forward(self, x):
        # Residual add, then upsample.
        return self.up_net(x + self.sub_net(x))
#編碼
class EncoderNet(torch.nn.Module):
    """Encoder: six ``ResConv2d`` stages, each halving the spatial size.

    Channels go 3 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 20.  For a 64x64
    input the spatial size goes 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1, so the
    code has shape (B, 20, 1, 1).
    """

    # Channel width entering/leaving each stage; the first entry is RGB.
    _CHANNELS = (3, 64, 128, 256, 512, 1024, 20)

    def __init__(self):
        super(EncoderNet, self).__init__()
        widths = self._CHANNELS
        self.sub_net = torch.nn.Sequential(
            *(ResConv2d(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])))

    def forward(self, x):
        return self.sub_net(x)
#解碼
class DecoderNet(torch.nn.Module):
    """Decoder: maps a (B, 20, 1, 1) code back to a (B, 3, 64, 64) image.

    Spatial sizes: 1 -> 4 (plain ConvTranspose2d) -> 8 -> 16 -> 32 (three
    ``ResConvTranspose2d`` stages) -> 64 (final ConvTranspose2d).  The last
    layer has no activation, so outputs are not restricted to [0, 1].
    """

    def __init__(self):
        super(DecoderNet, self).__init__()
        layers = [torch.nn.ConvTranspose2d(20, 1024, 4, 1, 0)]     # 1 -> 4
        layers.append(ResConvTranspose2d(1024, 512))               # 4 -> 8
        layers.append(ResConvTranspose2d(512, 256))                # 8 -> 16
        layers.append(ResConvTranspose2d(256, 128))                # 16 -> 32
        layers.append(torch.nn.ConvTranspose2d(128, 3, 4, 2, 1))   # 32 -> 64
        # NOTE: the attribute name "decorder" (sic) is deliberately kept: it is
        # the prefix of every key in this module's state_dict, which the
        # training script saves to decoder.pkl.
        self.decorder = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.decorder(x)
(3).train.py
import torch
import net
import dataset
import torch.nn as nn
import os
import shutil
from torch.utils.data import DataLoader
from torchvision.utils import save_image
# Reconstruction criterion used by MainNet.AELoss: mean squared error averaged
# over every element (reduction='mean' is MSELoss's default, spelled out here).
loss_f = nn.MSELoss(reduction='mean')
class MainNet(nn.Module):
    """Autoencoder: ``net.EncoderNet`` followed by ``net.DecoderNet``."""

    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = net.EncoderNet()
        self.decoder = net.DecoderNet()

    def forward(self, x1):
        """Encode ``x1`` and return the decoder's reconstruction."""
        return self.decoder(self.encoder(x1))

    def AELoss(self, y_, x0):
        """Return the module-level ``loss_f`` (MSE) between output ``y_`` and target ``x0``."""
        return loss_f(y_, x0)
#訓練
class Trainer(nn.Module):
    """Train ``MainNet`` to reconstruct clean images from occluded ones.

    Each batch from ``dataset.GetData`` is a ``(clean, occluded)`` pair: the
    occluded image is fed through the network and ``MainNet.AELoss`` compares
    the output with the clean image.  Every ``save_every`` batches the
    encoder/decoder weights are checkpointed to ``param_dir``, one
    (clean, occluded, reconstruction) sample is written to ``result_dir``,
    and the current batch loss is printed.

    Every constructor argument defaults to the value that used to be
    hard-coded, so ``Trainer().train()`` keeps working unchanged.

    Args:
        path0: directory of clean (target) images.
        path1: directory of occluded (input) images.
        param_dir: directory holding ``encoder.pkl`` / ``decoder.pkl``.
        result_dir: directory for sample images.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        save_every: checkpoint / sample / log interval, in batches.
    """

    def __init__(self, path0=r'C:\Users\87419\Desktop\data\64',
                 path1=r'C:\Users\87419\Desktop\data\64_dama',
                 param_dir='./param0', result_dir='./result0',
                 batch_size=128, lr=1e-3, save_every=25):
        super(Trainer, self).__init__()
        self.path0 = path0
        self.path1 = path1
        self.param_dir = param_dir
        self.result_dir = result_dir
        self.batch_size = batch_size
        self.save_every = save_every
        # Fall back to CPU so the script still runs without a CUDA device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.main_net = MainNet()
        self.main_net.to(self.device)
        # One reconstruction loss -> one optimizer over all autoencoder
        # parameters (main_net.parameters() yields encoder's then decoder's).
        self.opt_ae = torch.optim.Adam(self.main_net.parameters(), lr=lr)

    def _parts(self):
        """Return ``(file stem, submodule)`` pairs that are checkpointed."""
        return (('encoder', self.main_net.encoder),
                ('decoder', self.main_net.decoder))

    def _load_checkpoints(self):
        """Restore encoder/decoder weights from ``param_dir`` when present."""
        for stem, module in self._parts():
            path = os.path.join(self.param_dir, stem + '.pkl')
            if os.path.exists(path):
                module.load_state_dict(torch.load(path, map_location=self.device))

    def _save_checkpoints(self):
        """Write encoder/decoder weights to ``param_dir``."""
        os.makedirs(self.param_dir, exist_ok=True)
        for stem, module in self._parts():
            final_path = os.path.join(self.param_dir, stem + '.pkl')
            tmp_path = os.path.join(self.param_dir, stem + '_tmp.pkl')
            # Save to a temporary file and then atomically swap it in, so an
            # interrupted save never leaves a truncated checkpoint behind.
            torch.save(module.state_dict(), tmp_path)
            os.replace(tmp_path, final_path)

    def _save_samples(self, epoch, count, clean, occluded, output):
        """Save the first image of each batch tensor for visual inspection."""
        os.makedirs(self.result_dir, exist_ok=True)
        prefix = os.path.join(self.result_dir, '{}_{}'.format(epoch, count))
        save_image(clean[:1], prefix + '_0.jpg')              # original
        save_image(occluded[:1], prefix + '_1.jpg')           # occluded input
        save_image(output[:1].detach(), prefix + '_1_0.jpg')  # reconstruction

    def train(self, *, epochs=10000):
        """Run the training loop for ``epochs`` epochs.

        ``epochs`` is keyword-only so that nn.Module-style calls such as
        ``self.train(False)`` still fail loudly instead of training 0 epochs.
        """
        # Load once: reloading at the start of every epoch (as before)
        # discarded all updates made since the last periodic save.
        self._load_checkpoints()
        self.dataloader = DataLoader(
            dataset.GetData(path0=self.path0, path1=self.path1),
            batch_size=self.batch_size, shuffle=True)
        for epoch in range(epochs):
            self.main_net.train()  # training mode (affects BatchNorm)
            # One epoch is len(self.dataloader) batches of batch_size images.
            for count, (img0data, img1data) in enumerate(self.dataloader, start=1):
                img0data = img0data.to(self.device)  # clean target
                img1data = img1data.to(self.device)  # occluded network input
                y_ = self.main_net(img1data)
                aeloss = self.main_net.AELoss(y_, img0data)
                self.opt_ae.zero_grad()
                aeloss.backward()
                self.opt_ae.step()
                if count % self.save_every == 0:
                    self._save_checkpoints()
                    self._save_samples(epoch, count, img0data, img1data, y_)
                    # AELoss is already a per-element mean; the old division
                    # by len(self.dataloader) made the printed value meaningless.
                    print('epoch:', epoch, '|', 'count:', count, '|', 'aeloss:', aeloss.item())
if __name__ == '__main__':
    # Script entry point: build a trainer and start the training loop.
    trainer = Trainer()
    trainer.train()
(二).測試階段
(1).dataset_test.py
import torch
import os
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
class GetData(Dataset):
    """Single-directory image dataset used at test time.

    Item ``i`` is the i-th file of ``path0`` (after sorting by name),
    converted to an RGB tensor by ``transforms.ToTensor``.

    Args:
        path0: directory of images to reconstruct.
    """

    def __init__(self, path0):
        super(GetData, self).__init__()
        self.path0 = path0
        # os.listdir order is arbitrary; sorting makes the item order (and so
        # the numbering of any saved outputs) deterministic across runs.
        self.name0_list = self._sorted_names(self.path0)
        self.img2data = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _sorted_names(path):
        """Return the entries of directory ``path`` in sorted order."""
        return sorted(os.listdir(path))

    def __len__(self):
        return len(self.name0_list)

    def __getitem__(self, index):
        """Return the image at position ``index`` as a 3-channel tensor."""
        self.name0 = self.name0_list[index]
        # ``with`` closes the file handle; convert('RGB') guarantees the
        # 3 channels the encoder expects even for grayscale/RGBA files.
        with Image.open(os.path.join(self.path0, self.name0)) as img0:
            return self.img2data(img0.convert('RGB'))
(2).test.py
import torch
import net
import dataset_test
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from torchvision.utils import save_image
class MainNet(nn.Module):
    """Inference-time autoencoder: ``net.EncoderNet`` then ``net.DecoderNet``.

    ``Test.test`` loads ``encoder.pkl`` / ``decoder.pkl`` into the ``encoder``
    and ``decoder`` submodules separately.
    """

    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = net.EncoderNet()
        self.decoder = net.DecoderNet()

    def forward(self, x1):
        code = self.encoder(x1)
        return self.decoder(code)
class Test(nn.Module):
    """Run the trained autoencoder over a folder of images and save outputs.

    Every constructor argument defaults to the value that used to be
    hard-coded, so ``Test().test()`` keeps working unchanged.

    Args:
        param_dir: directory holding ``encoder.pkl`` / ``decoder.pkl``.
        data_dir: directory of input images (read via ``dataset_test.GetData``).
        out_dir: directory that receives ``1.jpg``, ``2.jpg``, ...
    """

    def __init__(self, param_dir='./param0',
                 data_dir=r'C:\Users\87419\Desktop\data\test',
                 out_dir=r'C:\Users\87419\Desktop\data\AE_test_result'):
        super(Test, self).__init__()
        self.param_dir = param_dir
        self.data_dir = data_dir
        self.out_dir = out_dir
        # Fall back to CPU so inference also works without a CUDA device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.main_net = MainNet()
        self.main_net.to(self.device)

    def test(self):
        """Reconstruct every image in ``data_dir`` and save it to ``out_dir``."""
        for stem, module in (('encoder', self.main_net.encoder),
                             ('decoder', self.main_net.decoder)):
            path = os.path.join(self.param_dir, stem + '.pkl')
            if os.path.exists(path):
                module.load_state_dict(torch.load(path, map_location=self.device))
            else:
                # Still best-effort (continue with untrained weights), but no
                # longer silent: the outputs would otherwise look like results.
                print('warning: {} not found; using untrained {} weights'.format(path, stem))
        self.dataloader = DataLoader(dataset_test.GetData(path0=self.data_dir))
        os.makedirs(self.out_dir, exist_ok=True)
        self.main_net.eval()  # inference mode (BatchNorm uses running stats)
        # no_grad: inference needs no autograd graph, which would only waste memory.
        with torch.no_grad():
            for count, img0data in enumerate(self.dataloader, start=1):
                img0data = img0data.to(self.device)
                encoded = self.main_net.encoder(img0data)
                decoded = self.main_net.decoder(encoded)
                save_image(decoded, os.path.join(self.out_dir, '{}.jpg'.format(count)))
if __name__ == '__main__':
    # Script entry point: build the tester and reconstruct the test images.
    tester = Test()
    tester.test()
測試效果如下(沒仔細訓練,只是意思一下):