一、這是個Pytorch學習案例,可以根據這個案例寫自己的模型
二、代碼
1、導入相關模塊
import torch
from torch import nn
import torchvision
import numpy as np
import cv2
%matplotlib inline
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random
import copy
import torch.optim as optim
2、定義網絡模型,這是一個迴歸模型,用於過濾高斯分佈的噪聲,從而復原分佈,這個模型定義的比較簡單
class myNet(nn.Module):
def __init__(self,):
super(myNet, self).__init__()
self.conv1=nn.Conv2d(1,3,kernel_size=3,padding=1)
self.conv2=nn.Conv2d(3,1,kernel_size=3,padding=1)
self.relu=nn.ReLU(inplace=True)
def forward(self,x):
x=self.conv1(x)
x=self.relu(x)
x=self.conv2(x)
return x
3、數據生成與測試
class Datagen(Dataset):
def __init__(self,size=12,transform=None,sigma=3):
self.size=size
self.transform=transform
self.db=[]
self.sigma=sigma
for i in range(10):
x=np.arange(0,self.size,1,np.float32)
y=x[:,np.newaxis]
#template=np.zeros((15,15))
template=np.exp(-((x-random.randint(0,self.size))**2+(y-random.randint(0,self.size))**2)/(2*self.sigma))
data_noisy = template + 0.2*np.random.normal(size=template.shape)
#self.db.append([data_noisy,np.exp(-((x-self.size//2)**2+(y-self.size//2)**2)/(2*self.sigma))])
self.db.append([data_noisy,template[None,:,:]])
def __len__(self,):
return len(self.db)
def __getitem__(self, idx):
db_rec = copy.deepcopy(self.db[idx])
db_rec[0]=self.transform(db_rec[0]).float()
#data=torch.from_numpy(db_rec)
return db_rec
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
#這是pytorch自帶的數據轉換爲tensor的函數,這個庫中還包含了對於圖像的數據增廣函數,很方便
#數據可視化
trainData[1][0].size()
plt.imshow(trainData[2][0][0,:])
plt.show()
4、定義loss以及數據初始化、因爲是迴歸模型,通常會採用l2作爲loss
#定義loss
criterion = nn.MSELoss(size_average=True).cuda()
#訓練數據初始化
trainData=Datagen(size=166,transform=transform,sigma=15)
train_loader = torch.utils.data.DataLoader(
trainData,
batch_size=1,
shuffle=True,
num_workers=2,
pin_memory=True
)
5、定義優化器以及初始化網絡模型
#網絡初始化
net=myNet()
model=net.cuda()
#定義優化器
optimizer = optim.Adam(model.parameters(),lr=0.001 )
6、訓練迭代器用for循環即可,這裏兩層循環分別是,epoch和iteration
#開始迭代訓練
epoch=50
for i in range(epoch):
sum_loss=0
for i, data in enumerate(train_loader):
input_data,target=data
input_data=input_data.cuda(non_blocking=True)
target=target.cuda(non_blocking=True)
output = model(input_data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
sum_loss+=loss.item()
msg='Loss:{loss:.4f}'.format(loss=sum_loss/10.)
print(msg)
7、模型的保存
torch.save(model.state_dict(), PATH)
8、編輯測試案例
#寫個測試例子
testData=Datagen(size=166,transform=transform,sigma=15)
test_loader = torch.utils.data.DataLoader(
testData,
batch_size=1,
shuffle=False,
num_workers=2,
pin_memory=True
)
model.eval()
for i,tData in enumerate(train_loader):
if i >=1:
break
test_input=tData[0].cuda(non_blocking=True)
test_target=tData[1].cuda(non_blocking=True)
pred_test=model(test_input)
plt.imshow(test_input[0,0,:,:].cpu().detach().numpy())
plt.show()
plt.imshow(pred_test[0,0,:,:].cpu().detach().numpy())
plt.show()
plt.imshow(test_target[0,0,:,:].cpu().detach().numpy())
plt.show()