【簡單的Pytorch迴歸模型案例】CNN去除隨機噪聲--修復2d高斯分佈【pytorch demo】

 

一、這是個Pytorch學習案例,可以根據這個案例寫自己的模型

二、代碼

1、導入相關模塊

import torch 
from torch import nn
import torchvision
import numpy as np
import cv2
%matplotlib inline
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset
import random
import copy
import torch.optim as optim

2、定義網絡模型,這是一個迴歸模型,用於過濾高斯分佈的噪聲,從而復原分佈,這個模型定義的比較簡單

class myNet(nn.Module):
    def __init__(self,):
        super(myNet, self).__init__()
        self.conv1=nn.Conv2d(1,3,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(3,1,kernel_size=3,padding=1)
        self.relu=nn.ReLU(inplace=True)
    def forward(self,x):
        x=self.conv1(x)
        x=self.relu(x)
        x=self.conv2(x)
        
        return x

3、數據生成與測試

class Datagen(Dataset):
    def __init__(self,size=12,transform=None,sigma=3):
        self.size=size
        self.transform=transform
        self.db=[]
        self.sigma=sigma
        for i in range(10):
            x=np.arange(0,self.size,1,np.float32)
            y=x[:,np.newaxis]
            #template=np.zeros((15,15))
            template=np.exp(-((x-random.randint(0,self.size))**2+(y-random.randint(0,self.size))**2)/(2*self.sigma))
            data_noisy = template + 0.2*np.random.normal(size=template.shape)
            #self.db.append([data_noisy,np.exp(-((x-self.size//2)**2+(y-self.size//2)**2)/(2*self.sigma))])
            self.db.append([data_noisy,template[None,:,:]])
            
    def __len__(self,):
        return len(self.db)
    
    def __getitem__(self, idx):
        db_rec = copy.deepcopy(self.db[idx])
        db_rec[0]=self.transform(db_rec[0]).float()
        #data=torch.from_numpy(db_rec)
        
        return db_rec
    
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
#這是pytorch自帶的數據轉換爲tensor的函數,這個庫中還包含了對於圖像的數據增廣函數,很方便
#數據可視化
trainData[1][0].size()
plt.imshow(trainData[2][0][0,:])
plt.show()

4、定義loss以及數據初始化、因爲是迴歸模型,通常會採用l2作爲loss

#定義loss
criterion = nn.MSELoss(size_average=True).cuda()

#訓練數據初始化
trainData=Datagen(size=166,transform=transform,sigma=15)
train_loader = torch.utils.data.DataLoader(
        trainData,
        batch_size=1,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

5、定義優化器以及初始化網絡模型

#網絡初始化
net=myNet()
model=net.cuda()
#定義優化器
optimizer = optim.Adam(model.parameters(),lr=0.001 )

6、訓練迭代器用for循環即可,這裏兩層循環分別是,epoch和iteration

#開始迭代訓練
epoch=50
for i in range(epoch):
    sum_loss=0
    for i, data in enumerate(train_loader):
        input_data,target=data
        
        input_data=input_data.cuda(non_blocking=True)
        
        target=target.cuda(non_blocking=True)
    
    
        output = model(input_data)
    
    
    
        loss = criterion(output, target)
    
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        sum_loss+=loss.item()
    msg='Loss:{loss:.4f}'.format(loss=sum_loss/10.)
    print(msg)

7、模型的保存

torch.save(model.state_dict(), PATH)

8、編輯測試案例

#寫個測試例子
testData=Datagen(size=166,transform=transform,sigma=15)
test_loader = torch.utils.data.DataLoader(
        testData,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
model.eval()
for i,tData in enumerate(train_loader):
    if i >=1:
        break
    test_input=tData[0].cuda(non_blocking=True)
    test_target=tData[1].cuda(non_blocking=True)
    pred_test=model(test_input)
    
    plt.imshow(test_input[0,0,:,:].cpu().detach().numpy())
    plt.show()
    
    plt.imshow(pred_test[0,0,:,:].cpu().detach().numpy())
    plt.show()
    
    plt.imshow(test_target[0,0,:,:].cpu().detach().numpy())
    plt.show()
    

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章