Getting Started with PyTorch: Image Segmentation with a UNet

The PyTorch code in the papers I have been reading lately is too complex, and I had never used PyTorch before, so I decided to first write a bare-bones implementation myself. Working through it once this way leaves a much stronger impression of the basic workflow of training a network and of the common building-block functions.

The code and data in this post come mainly from https://blog.csdn.net/jiangpeng59/article/details/80189889

The author's GitHub repository: https://github.com/JavisPeng/u_net_liver

I made some changes based on my own understanding and added extensive comments.

If you spot any mistakes, please point them out.

 unet.py (implements the UNet network)

import torch.nn as nn
import torch

class DoubleConv(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DoubleConv,self).__init__()
        self.conv = nn.Sequential(
                nn.Conv2d(in_ch,out_ch,3,padding=1),#in_ch / out_ch are the numbers of input / output channels
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace = True),
                nn.Conv2d(out_ch,out_ch,3,padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace = True)  
            )
    def forward(self,x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(UNet,self).__init__()
        self.conv1 = DoubleConv(in_ch,64)
        self.pool1 = nn.MaxPool2d(2)#each pooling halves the spatial size
        self.conv2 = DoubleConv(64,128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128,256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256,512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512,1024)
        #transposed convolutions for upsampling; each decoder stage concatenates the upsampled
        #feature map with the matching encoder output, hence the doubled input channels of conv6-conv9
        self.up6 = nn.ConvTranspose2d(1024,512,2,stride=2)
        self.conv6 = DoubleConv(1024,512)
        self.up7 = nn.ConvTranspose2d(512,256,2,stride=2)
        self.conv7 = DoubleConv(512,256)
        self.up8 = nn.ConvTranspose2d(256,128,2,stride=2)
        self.conv8 = DoubleConv(256,128)
        self.up9 = nn.ConvTranspose2d(128,64,2,stride=2)
        self.conv9 = DoubleConv(128,64)
        
        self.conv10 = nn.Conv2d(64,out_ch,1)
        
    
    def forward(self,x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6,c4],dim=1)#concatenate along dim=1, the channel dimension (512 + 512 = 1024)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7,c3],dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8,c2],dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9,c1],dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        
        out = torch.sigmoid(c10)#squash the logits into (0, 1)
        return out
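
As a quick sanity check of the architecture (my own addition, not part of the original repo), the snippet below feeds a dummy batch through the network: the four poolings shrink the feature map by a factor of 16 and the four transposed convolutions bring it back, so a 3-channel input yields a 1-channel mask of the same height and width.

# sanity check -- a minimal sketch, assuming unet.py sits in the same directory
import torch
from unet import UNet

net = UNet(3, 1)                     # 3 input channels (RGB), 1 output channel (mask)
x = torch.randn(2, 3, 512, 512)      # dummy batch: N=2, C=3; H and W should be divisible by 16
with torch.no_grad():
    y = net(x)
print(x.shape, y.shape)              # torch.Size([2, 3, 512, 512]) torch.Size([2, 1, 512, 512])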
        

 dataset.py

import torch.utils.data as data
import os
import PIL.Image as Image

#data.Dataset:
#every subclass should override __len__ and __getitem__; the former reports the dataset size,
#the latter supports integer indexing in the range 0 to len(self)-1

class LiverDataset(data.Dataset):
    #instantiating LiverDataset calls __init__ to do the initialization
    def __init__(self,root,transform = None,target_transform = None):#root is the directory that holds the images
        n = len(os.listdir(root))//2 #os.listdir(path) lists the files/folders under path; // is floor division, so n is the number of image/mask pairs
        
        imgs = []
        for i in range(n):
            img = os.path.join(root,"%03d.png"%i)#os.path.join(path1[,path2[,...]]) joins the path components and returns the result
            mask = os.path.join(root,"%03d_mask.png"%i)
            imgs.append([img,mask])#append takes a single argument, so wrap the image/mask pair in a list
        
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
    
    
    def __getitem__(self,index):
        x_path,y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x,img_y#returns the (transformed) image and its mask
    
    
    def __len__(self):
        return len(self.imgs)#400 for this dataset; each element is a pair [img,mask]
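
To verify that the dataset pairs images and masks correctly, it can be instantiated and indexed directly. This is a small sketch of my own, assuming the data/train folder layout described above (000.png, 000_mask.png, ...); the printed shapes depend on the image size in the dataset.

# dataset sanity check -- a minimal sketch, not part of the original code
from torchvision.transforms import transforms as T
from dataset import LiverDataset

ds = LiverDataset("data/train", transform=T.ToTensor(), target_transform=T.ToTensor())
print(len(ds))                 # number of image/mask pairs (400 for this dataset)
img, mask = ds[0]              # __getitem__ returns the transformed image and mask tensors
print(img.shape, mask.shape)   # e.g. torch.Size([3, 512, 512]) and torch.Size([1, 512, 512])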

main.py

import torch
from torchvision.transforms import transforms as T
import argparse #argparse parses command-line arguments, e.g. python parseTest.py input.txt --port=8080
import unet
from torch import optim
from dataset import LiverDataset
from torch.utils.data import DataLoader


# use the current CUDA device if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

x_transform = T.Compose([
    T.ToTensor(),
    # normalize to [-1,1] with the given per-channel mean and std
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])#torchvision.transforms.Normalize(mean, std, inplace=False)
])
# the mask only needs to be converted to a tensor
y_transform = T.ToTensor()

def train_model(model,criterion,optimizer,dataload,num_epochs=20):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dataset_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0 #minibatch counter
        for x, y in dataload:#iterate over the dataset one minibatch at a time (batch_size=4, so 100 steps for 400 samples)
            optimizer.zero_grad()#zero the gradients (dw,db,...) at every minibatch
            inputs = x.to(device)
            labels = y.to(device)
            outputs = model(inputs)#forward pass
            loss = criterion(outputs, labels)#compute the loss
            loss.backward()#backpropagation: compute the gradients
            optimizer.step()#update the parameters once; every Optimizer implements step() to update all parameters
            epoch_loss += loss.item()
            step += 1
            print("%d/%d,train_loss:%0.3f" % (step, dataset_size // dataload.batch_size, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
    torch.save(model.state_dict(),'weights_%d.pth' % epoch)# save the learned parameters (state_dict) after the last epoch
    return model

#train the model
def train():
    model = unet.UNet(3,1).to(device)
    batch_size = args.batch_size
    #loss function
    criterion = torch.nn.BCELoss()
    #optimizer (Adam with its default learning rate)
    optimizer = optim.Adam(model.parameters())#model.parameters():Returns an iterator over module parameters
    #load the dataset
    liver_dataset = LiverDataset("data/train", transform=x_transform, target_transform=y_transform)
    dataloader = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True,num_workers=4)
    # DataLoader wraps a Dataset (custom or built-in) and yields its samples batched into tensors
    # batch_size: how many samples per minibatch; 4 here, so a 400-sample dataset gives 100 minibatches
    # shuffle: reshuffle the data at every epoch; usually enabled for training data
    # num_workers: number of worker processes used to load data in parallel, which speeds up loading
    train_model(model,criterion,optimizer,dataloader)

#test / inference
def test():
    model = unet.UNet(3,1)
    model.load_state_dict(torch.load(args.weight,map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transform, target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)#batch_size defaults to 1
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y=model(x)
            img_y=torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()


if __name__ == '__main__':
    #parse command-line arguments
    parser = argparse.ArgumentParser() #create an ArgumentParser object
    parser.add_argument('action', type=str, help='train or test')#positional argument: train or test
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--weight', type=str, help='the path of the model weight file')
    args = parser.parse_args()
    
    if args.action == 'train':
        train()
    elif args.action == 'test':
        test()
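
With the argparse setup above, the script would be run roughly like this (the saved filename comes from the final epoch index, so with the default num_epochs=20 it is weights_19.pth):

python main.py train --batch_size 4
python main.py test --weight weights_19.pth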

 
