Autoencoder (AE) for Image De-mosaicking in PyTorch

Early last year I wrote an AE implementation, but it was long-winded and not very mature. I came across it again today, so I rewrote it from scratch.

I. Code

1. File overview

The project consists of five files: dataset.py, net.py, and train.py for the training phase, plus dataset_test.py and test.py for the testing phase.

2. Main programs

(I). Training phase

(1).dataset.py

import torch
import os
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class GetData(Dataset):
    def __init__(self,path0,path1): # build lists of file names for the clean and mosaicked images
        super(GetData,self).__init__()
        self.path0 = path0
        self.path1 = path1
        self.name0_list = sorted(os.listdir(self.path0))  # sort so clean/mosaic pairs line up by name
        self.name1_list = sorted(os.listdir(self.path1))
        self.img2data = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.name0_list)

    def __getitem__(self, index): # load one clean/mosaic image pair by name
        self.name0 = self.name0_list[index]
        self.name1 = self.name1_list[index]
        img0 = Image.open(os.path.join(self.path0, self.name0)).convert('RGB')  # force 3 channels
        img1 = Image.open(os.path.join(self.path1, self.name1)).convert('RGB')
        imgdata0 = self.img2data(img0)
        imgdata1 = self.img2data(img1)

        return imgdata0, imgdata1
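The loader above expects two parallel folders: one with clean 64x64 crops and one with their mosaicked counterparts under the same file names. The original post does not show how the mosaicked set was produced; the following is a minimal sketch, assuming the folder paths used later in train.py and a hypothetical 8x8-block pixelation of the central region.

import os
from PIL import Image

# Sketch only: folder names follow the paths used in train.py,
# the pixelation scheme itself is an assumption.
src_dir = r'C:\Users\87419\Desktop\data\64'       # clean 64x64 images
dst_dir = r'C:\Users\87419\Desktop\data\64_dama'  # mosaicked copies
os.makedirs(dst_dir, exist_ok=True)

for name in os.listdir(src_dir):
    img = Image.open(os.path.join(src_dir, name)).convert('RGB')
    box = (16, 16, 48, 48)                    # central 32x32 region
    patch = img.crop(box)
    # shrink to 4x4 and blow back up with nearest-neighbour -> 8x8 mosaic blocks
    patch = patch.resize((4, 4), Image.NEAREST).resize((32, 32), Image.NEAREST)
    img.paste(patch, box)
    img.save(os.path.join(dst_dir, name))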

(2).net.py

import torch
import torch.nn as nn

# Residual convolution block (downsamples by 2)
class ResConv2d(torch.nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ResConv2d, self).__init__()

        self.sub_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0), # 1x1 convolution
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.Conv2d(in_channels, in_channels, 3, 1, 1),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),
        )

        self.down_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 4, 2, 1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.PReLU()
        )

    def forward(self, x):
        y = self.sub_net(x)
        return self.down_net(x + y) # add the residual before downsampling

# Residual transposed-convolution block (upsamples by 2)
class ResConvTranspose2d(torch.nn.Module):

    def __init__(self, in_channels, out_channels):
        super(ResConvTranspose2d, self).__init__()

        self.sub_net = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_channels, in_channels, 1, 1, 0),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.ConvTranspose2d(in_channels, in_channels, 3, 1, 1),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),

            torch.nn.ConvTranspose2d(in_channels, in_channels, 1, 1, 0),
            torch.nn.BatchNorm2d(in_channels),
            torch.nn.PReLU(),
        )

        self.up_net = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.PReLU(),
        )

    def forward(self, x):
        y = self.sub_net(x)
        return self.up_net(x + y)

# Encoder
class EncoderNet(torch.nn.Module):
    def __init__(self):
        super(EncoderNet, self).__init__()

        self.sub_net = torch.nn.Sequential(
            ResConv2d(3, 64),      # 64x64 -> 32x32
            ResConv2d(64, 128),    # -> 16x16
            ResConv2d(128, 256),   # -> 8x8
            ResConv2d(256, 512),   # -> 4x4
            ResConv2d(512, 1024),  # -> 2x2
            ResConv2d(1024, 20)    # -> 1x1 code with 20 channels
        )

    def forward(self, x):
        return self.sub_net(x)

# Decoder
class DecoderNet(torch.nn.Module):
    def __init__(self):
        super(DecoderNet, self).__init__()

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(20, 1024, 4, 1, 0),  # 1x1 -> 4x4
            ResConvTranspose2d(1024, 512),  # -> 8x8
            ResConvTranspose2d(512, 256),   # -> 16x16
            ResConvTranspose2d(256, 128),   # -> 32x32
            torch.nn.ConvTranspose2d(128, 3, 4, 2, 1)  # -> 64x64, 3 channels
        )

    def forward(self, x):
        return self.decoder(x)
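As a quick sanity check (not in the original post), a random 64x64 batch can be pushed through both networks; the encoder should produce a 20-channel 1x1 code and the decoder should map it back to 3x64x64. The snippet below can be appended to net.py:

if __name__ == '__main__':
    # shape check for 64x64 RGB input (batch of 2 so BatchNorm has something to normalize)
    x = torch.randn(2, 3, 64, 64)
    z = EncoderNet()(x)
    print(z.shape)   # torch.Size([2, 20, 1, 1])
    y = DecoderNet()(z)
    print(y.shape)   # torch.Size([2, 3, 64, 64])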

(3).train.py

import torch
import net
import dataset
import torch.nn as nn
import os
import shutil
from torch.utils.data import DataLoader
from torchvision.utils import save_image

loss_f = nn.MSELoss()
class MainNet(nn.Module): 
    def __init__(self):
        super(MainNet,self).__init__()

        self.encoder = net.EncoderNet()
        self.decoder = net.DecoderNet()

    def forward(self,x1):
        y = self.encoder(x1)
        y_ = self.decoder(y)
        return y_

    def AELoss(self, y_, x0):
        return loss_f(y_, x0)

# Trainer
class Trainer(nn.Module):
    def __init__(self):
        super(Trainer,self).__init__()

        self.main_net = MainNet()
        self.main_net.cuda()

        # only one reconstruction loss here, so a single optimizer covers both encoder and decoder
        ae_parameters = []
        ae_parameters.extend(self.main_net.encoder.parameters())
        ae_parameters.extend(self.main_net.decoder.parameters())
        self.opt_ae = torch.optim.Adam(ae_parameters, lr=1e-3)

    def train(self):
        os.makedirs('./param0', exist_ok=True)   # checkpoint directory
        os.makedirs('./result0', exist_ok=True)  # sample-image directory
        for epoch in range(10000):
            # resume from the latest checkpoint if one exists
            if os.path.exists('./param0/encoder.pkl'):
                self.main_net.encoder.load_state_dict(torch.load('./param0/encoder.pkl'))
            if os.path.exists('./param0/decoder.pkl'):
                self.main_net.decoder.load_state_dict(torch.load('./param0/decoder.pkl'))

            self.dataloader = DataLoader(dataset.GetData(path0=r'C:\Users\87419\Desktop\data\64',
                             path1=r'C:\Users\87419\Desktop\data\64_dama'), batch_size=128, shuffle=True)
            count = 0

            # Each epoch walks through the whole training set; every loop step ("count")
            # processes one batch of batch_size images, so len(dataloader) equals
            # total images / batch_size, i.e. the number of count steps per epoch.
            for img0data, img1data in self.dataloader:

                img0data = img0data.cuda()  # move inputs to the GPU; downstream tensors then live there too
                img1data = img1data.cuda()

                count += 1

                self.main_net.train()  # training mode
                # each backward pass needs a fresh forward pass over the data
                y_ = self.main_net(img1data)
                # AE reconstruction loss and parameter update
                aeloss = self.main_net.AELoss(y_, img0data)
                self.opt_ae.zero_grad()
                aeloss.backward()
                self.opt_ae.step()

                if count % 25 == 0:
                    self.main_net.eval()  # evaluation mode
                    # save to a temporary file first, then copy it over the checkpoint,
                    # so an interrupted save cannot corrupt the previous weights
                    torch.save(self.main_net.encoder.state_dict(), './param0/encoder_tmp.pkl')
                    shutil.copyfile('./param0/encoder_tmp.pkl', './param0/encoder.pkl')
                    torch.save(self.main_net.decoder.state_dict(), './param0/decoder_tmp.pkl')
                    shutil.copyfile('./param0/decoder_tmp.pkl', './param0/decoder.pkl')

                    save_image(img0data[:1], './result0/{}_{}_0.jpg'.format(epoch, count))   # clean image
                    save_image(img1data[:1], './result0/{}_{}_1.jpg'.format(epoch, count))   # mosaicked input
                    save_image(y_[:1], './result0/{}_{}_1_0.jpg'.format(epoch, count))       # reconstruction

                    print('epoch:', epoch, '|', 'count:', count, '|', 'aeloss:', aeloss.item())

if __name__ == '__main__':
    Trainer().train()
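Reconstruction quality is easier to track with a number than by eyeballing the saved images. A small PSNR helper (an addition here, not part of the original script; it assumes image tensors scaled to [0, 1]) could be printed next to aeloss:

import torch

def psnr(x, y, max_val=1.0):
    # peak signal-to-noise ratio between two image tensors in [0, max_val]; higher is better
    mse = torch.mean((x - y) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

# hypothetical usage inside the training loop:
#   print('psnr:', psnr(y_.clamp(0, 1), img0data).item())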

(II). Testing phase

(1).dataset_test.py

import torch
import os
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class GetData(Dataset):
    def __init__(self,path0): # build a list of test file names
        super(GetData,self).__init__()
        self.path0 = path0
        self.name0_list = os.listdir(self.path0)
        self.img2data = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.name0_list)

    def __getitem__(self, index): # load one test image by name
        self.name0 = self.name0_list[index]
        img0 = Image.open(os.path.join(self.path0, self.name0)).convert('RGB')  # force 3 channels
        imgdata0 = self.img2data(img0)

        return imgdata0

(2).test.py

import torch
import net
import dataset_test
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from torchvision.utils import save_image

class MainNet(nn.Module):
    def __init__(self):
        super(MainNet,self).__init__()
        self.encoder = net.EncoderNet()
        self.decoder = net.DecoderNet()

    def forward(self,x1):
        y = self.encoder(x1)
        y_ = self.decoder(y)
        return y_

class Test(nn.Module):
    def __init__(self):
        super(Test,self).__init__()

        self.main_net = MainNet()
        self.main_net.cuda()

    def test(self):
        if os.path.exists('./param0/encoder.pkl'):
            self.main_net.encoder.load_state_dict(torch.load('./param0/encoder.pkl'))
        if os.path.exists('./param0/decoder.pkl'):
            self.main_net.decoder.load_state_dict(torch.load('./param0/decoder.pkl'))

        self.dataloader = DataLoader(dataset_test.GetData(path0=r'C:\Users\87419\Desktop\data\test'))
        count = 0
        self.main_net.eval()  # evaluation mode
        with torch.no_grad():  # no gradients needed at test time
            for img0data in self.dataloader:
                img0data = img0data.cuda()
                encoded = self.main_net.encoder(img0data)
                decoded = self.main_net.decoder(encoded)
                count += 1
                save_image(decoded, r'C:\Users\87419\Desktop\data\AE_test_result/{}.jpg'.format(count))

if __name__ == '__main__':
    Test().test()
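For a quick visual check it can also help to write the mosaicked input and its reconstruction into a single file rather than separate images. A small helper (an addition, not part of the original script) built on the same save_image utility:

import torch
from torchvision.utils import save_image

def save_comparison(inp, out, path):
    # top row: mosaicked input, bottom row: reconstruction
    grid = torch.cat([inp.cpu(), out.cpu()], dim=0)
    save_image(grid, path, nrow=inp.shape[0])

# hypothetical usage inside the test loop:
#   save_comparison(img0data, decoded, './compare_{}.jpg'.format(count))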

Test results are shown below (the model was not trained for long; this is just to show the idea):
