最近看的paper裏的pytorch代碼太複雜,我之前也沒接觸過pytorch,遂決定先自己實現一個基礎的裸代碼,這樣走一遍,對跑網絡的基本流程和一些常用的基礎函數的印象會更深刻。
本文的代碼和數據主要來自 https://blog.csdn.net/jiangpeng59/article/details/80189889 ,
附上該博主的github地址: https://github.com/JavisPeng/u_net_liver 。
我在自己理解的基礎上做了一些改動,並加了大量註釋。
如有錯誤,歡迎指出。
unet.py(實現unet網絡)
import torch.nn as nn
import torch
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BN and ReLU.

    Every convolution uses a 3x3 kernel with ``padding=1``, so the spatial
    size of the input is preserved; only the channel count changes
    (``in_ch`` -> ``out_ch`` in the first stage, ``out_ch`` -> ``out_ch`` in
    the second).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage consumes in_ch channels, second stage consumes out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # The attribute stays named ``conv`` so parameter names remain
        # conv.0.*, conv.1.*, conv.3.*, conv.4.* in the state dict.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        features = self.conv(x)
        return features
class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` stages.

    The contracting path applies four ``DoubleConv`` + 2x2 max-pool stages
    (64, 128, 256 and 512 channels) followed by a 1024-channel bottleneck.
    The expanding path upsamples four times with stride-2 transposed
    convolutions; after each upsampling the result is concatenated along the
    channel dimension with the encoder output of the same level and passed
    through another ``DoubleConv``.  A final 1x1 convolution maps the 64
    feature channels to ``out_ch`` and a sigmoid squashes the values into
    the (0, 1) range.

    Args:
        in_ch: number of channels of the input images.
        out_ch: number of channels of the output map.
    """

    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()
        # Contracting path: each MaxPool2d(2) halves height and width.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Expanding path: each transposed convolution doubles height and
        # width and halves the channel count; the following DoubleConv takes
        # twice as many channels because of the skip concatenation.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 convolution producing the final out_ch channels.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _pad_and_concat(upsampled, skip):
        """Concatenate ``upsampled`` and ``skip`` along the channel dimension.

        Max-pooling floors odd sizes, so after upsampling a feature map can
        be one pixel smaller than the matching encoder output.  In that case
        ``upsampled`` is zero-padded to the size of ``skip`` first.  When the
        sizes already agree the tensors are concatenated untouched.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            upsampled = nn.functional.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        # dim=1 is the channel dimension of an (N, C, H, W) tensor.
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: 4-D tensor laid out as (N, in_ch, H, W).  H and W must be at
                least 16 so that four successive halvings leave a non-empty
                feature map; they no longer need to be multiples of 16.

        Returns:
            Tensor of shape (N, out_ch, H, W) with values in (0, 1).
        """
        # Encoder
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        # Decoder with skip connections
        c6 = self.conv6(self._pad_and_concat(self.up6(c5), c4))
        c7 = self.conv7(self._pad_and_concat(self.up7(c6), c3))
        c8 = self.conv8(self._pad_and_concat(self.up8(c7), c2))
        c9 = self.conv9(self._pad_and_concat(self.up9(c8), c1))
        c10 = self.conv10(c9)
        # torch.sigmoid avoids building a new nn.Sigmoid module on every call.
        return torch.sigmoid(c10)
dataset.py
import torch.utils.data as data
import os
import PIL.Image as Image
#data.Dataset:
#所有子類都應該重寫(override) __len__ 和 __getitem__:前者返回數據集的大小,後者支持整數索引,索引範圍是0到len(self)-1
class LiverDataset(data.Dataset):
    """Dataset of (image, mask) file pairs stored side by side in one folder.

    A sample consists of ``<stem>.png`` and ``<stem>_mask.png`` inside
    ``root`` (for example ``000.png`` / ``000_mask.png``).  Pairs are
    discovered from the mask files that actually exist, so unrelated
    directory entries (hidden files, sub-folders, ...) no longer inflate the
    sample count the way ``len(os.listdir(root)) // 2`` did.

    Args:
        root: directory containing the image and mask files.
        transform: optional callable applied to the opened image.
        target_transform: optional callable applied to the opened mask.

    Attributes set by ``__init__``:
        imgs: list of ``[image_path, mask_path]`` pairs.
        transform / target_transform: the callables passed in.
    """

    def __init__(self, root, transform=None, target_transform=None):
        import warnings

        mask_suffix = "_mask.png"
        mask_names = [
            name for name in os.listdir(root) if name.endswith(mask_suffix)
        ]
        # Ordering by (length, text) keeps zero-padded numeric names such as
        # 000 ... 999, 1000 in numeric order, i.e. the order the previous
        # "%03d" % i enumeration produced.
        mask_names.sort(key=lambda name: (len(name), name))

        imgs = []
        for mask_name in mask_names:
            stem = mask_name[:-len(mask_suffix)]
            img_path = os.path.join(root, stem + ".png")
            if not os.path.isfile(img_path):
                # A mask without its image cannot form a sample; report it
                # instead of failing later inside __getitem__.
                warnings.warn(
                    "skipping %s: no matching image %s"
                    % (mask_name, stem + ".png")
                )
                continue
            imgs.append([img_path, os.path.join(root, mask_name)])

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Open the image/mask pair at ``index`` and apply the transforms.

        Returns:
            ``(image, mask)``; each element is whatever the corresponding
            transform returns, or the opened PIL image when the transform
            is ``None``.
        """
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.imgs)
main.py
import torch
from torchvision.transforms import transforms as T
import argparse #argparse模塊的作用是用於解析命令行參數,例如python parseTest.py input.txt --port=8080
import unet
from torch import optim
from dataset import LiverDataset
from torch.utils.data import DataLoader
# Run on the current CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Input images: convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5 (values in [0, 1] end up in [-1, 1]).
x_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Masks are only converted to tensors; no normalisation is applied.
y_transform = T.ToTensor()
def train_model(model, criterion, optimizer, dataload, num_epochs=20):
    """Train ``model`` for ``num_epochs`` passes over ``dataload``.

    For every mini-batch the gradients are cleared, the forward pass and the
    loss are computed, the loss is back-propagated and the optimizer takes
    one step.  Progress is printed per batch and per epoch (the epoch figure
    is the sum of the batch losses).  After the last epoch the model's
    ``state_dict`` is written to ``weights_<last_epoch>.pth`` in the current
    working directory.

    Args:
        model: the network to train; batches are moved to the device its
            parameters live on.
        criterion: callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: optimizer over the model's parameters.
        dataload: DataLoader yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over the data.

    Returns:
        The same ``model`` object, trained in place.
    """
    # Use the model's own device rather than a module-level global so inputs
    # always end up where the weights are.
    model_device = next(model.parameters()).device
    # len(DataLoader) is the real number of batches per epoch; the previous
    # dataset_size // batch_size under-counted when the last batch is partial.
    steps_per_epoch = len(dataload)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        for step, (x, y) in enumerate(dataload, start=1):
            # Gradients accumulate by default, so reset them for each batch.
            optimizer.zero_grad()
            inputs = x.to(model_device)
            labels = y.to(model_device)
            outputs = model(inputs)  # forward pass
            loss = criterion(outputs, labels)
            loss.backward()  # back-propagation: compute the gradients
            optimizer.step()  # apply one parameter update
            batch_loss = loss.item()
            epoch_loss += batch_loss
            print("%d/%d,train_loss:%0.3f" % (step, steps_per_epoch, batch_loss))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))

    # NOTE(review): the published listing lost its indentation; saving once
    # after the final epoch is assumed here -- confirm against the original.
    if num_epochs > 0:
        # Persist only the parameters/buffers (state_dict), not the module.
        torch.save(model.state_dict(), 'weights_%d.pth' % (num_epochs - 1))
    return model
# Train the model
def train():
    """Build a 3-channel-in / 1-channel-out UNet and train it on data/train.

    Reads ``batch_size`` from the module-level ``args`` namespace, uses
    binary cross-entropy as the loss and Adam with default settings as the
    optimizer, then hands everything to ``train_model``.
    """
    net = unet.UNet(3, 1).to(device)
    batch_size = args.batch_size
    # Loss function: binary cross-entropy on the network's sigmoid output.
    loss_fn = torch.nn.BCELoss()
    # Optimizer: Adam over every parameter returned by net.parameters().
    adam = optim.Adam(net.parameters())
    # Dataset of image/mask pairs with the module-level transforms applied.
    train_set = LiverDataset(
        "data/train",
        transform=x_transform,
        target_transform=y_transform,
    )
    # DataLoader groups samples into mini-batches of `batch_size`,
    # reshuffles them every epoch (shuffle=True) and loads them with four
    # worker processes (num_workers=4).
    loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
    train_model(net, loss_fn, adam, loader)
# Inference / visual check
def test():
    """Load saved weights and display the prediction for each data/val sample.

    The weight file path comes from the module-level ``args.weight``; the
    model runs on the CPU in eval mode with gradient tracking disabled, and
    each predicted map is shown with matplotlib in interactive mode.
    """
    net = unet.UNet(3, 1)
    # NOTE(review): torch.load unpickles the file -- only load weight files
    # from a trusted source.
    state = torch.load(args.weight, map_location='cpu')
    net.load_state_dict(state)
    val_set = LiverDataset(
        "data/val",
        transform=x_transform,
        target_transform=y_transform,
    )
    loader = DataLoader(val_set)  # batch_size defaults to 1
    net.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in loader:
            prediction = net(x)
            # Drop the size-1 dimensions so imshow receives a plain 2-D array.
            plt.imshow(prediction.squeeze().numpy())
            plt.pause(0.01)
    plt.show()
if __name__ == '__main__':
    # Command-line entry point, e.g.
    #   python main.py train --batch_size 4
    #   python main.py test --weight weights_19.pth
    parser = argparse.ArgumentParser()
    # Restrict the action so a typo is reported instead of silently doing
    # nothing.
    parser.add_argument('action', type=str, choices=['train', 'test'],
                        help='train or test')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--weight', type=str, help='the path of the mode weight file')
    # `args` stays a module-level name because train() and test() read it.
    args = parser.parse_args()
    if args.action == 'train':
        train()
    elif args.action == 'test':
        # Fail with a usage message rather than inside torch.load(None).
        if args.weight is None:
            parser.error('--weight is required when action is test')
        test()