[pytorch]醫學圖像之肝臟語義分割(訓練+預測代碼)

一,Unet結構:

結合上圖的Unet結構,pytorch的unet代碼如下:

unet.py:

import torch.nn as nn
import torch
from torch import autograd


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> batch-norm -> ReLU) stages.

    padding=1 with a 3x3 kernel keeps the spatial size unchanged, so the
    block only changes the channel count: in_ch -> out_ch -> out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = []
        for c_in, c_out in ((in_ch, out_ch), (out_ch, out_ch)):
            stages.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    """Classic U-Net: 4-level encoder/decoder with skip connections.

    Input: ``in_ch``-channel images whose height and width are divisible by
    16 (four 2x2 max-pools). Output: ``out_ch``-channel map of the same
    spatial size, squashed to (0, 1) by a final sigmoid.

    Fix: the original forward did ``nn.Sigmoid()(c10)``, constructing a new
    module object on every call; replaced with the functional
    ``torch.sigmoid`` (identical output, no per-call allocation).
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()

        # Encoder: channel width doubles at each down-sampling level.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: transposed convs up-sample 2x; each DoubleConv takes the
        # doubled channel count produced by concatenating the skip feature.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv maps the final 64 features to the requested class count.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Contracting path; keep c1..c4 for the skip connections.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Expanding path: up-sample, concatenate the skip feature on the
        # channel axis, then fuse with a DoubleConv.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # Functional sigmoid instead of instantiating nn.Sigmoid per call.
        out = torch.sigmoid(c10)
        return out

二,數據集:肝臟圖

三,代碼

主代碼:main.py

import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from mIou import *
import os
import cv2
# 是否使用cuda
import PIL.Image as Image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_data(i):
    """Return (image_path, mask_path) of the i-th validation sample."""
    import dataset
    pairs = dataset.make_dataset(r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\val")
    image_paths = [pair[0] for pair in pairs]
    mask_paths = [pair[1] for pair in pairs]
    return image_paths[i], mask_paths[i]



def train_model(model, criterion, optimizer, dataload, num_epochs=21):
    """Run the training loop for ``num_epochs`` epochs and save the weights.

    Prints per-batch and per-epoch loss; the final state dict is written to
    the fixed path below after the last epoch.
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        sample_count = len(dataload.dataset)
        batches_per_epoch = (sample_count - 1) // dataload.batch_size + 1
        epoch_loss = 0
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            # Standard step: clear grads, forward, loss, backward, update.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, batches_per_epoch, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
    # Persist only the final epoch's weights.
    torch.save(model.state_dict(), r'H:\BaiduNetdisk\BaiduDownload\u_net_liver-master/weights.pth')
    return model


# 訓練模型
# Train the model
def train():
    """Train a fresh U-Net (3 input channels, 1 output channel) on the liver set."""
    net = Unet(3, 1).to(device)
    loss_fn = torch.nn.BCELoss()
    opt = optim.Adam(net.parameters())
    train_set = LiverDataset(
        r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\train",
        transform=x_transforms,
        target_transform=y_transforms,
    )
    loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0)
    train_model(net, loss_fn, opt, loader)


# 顯示模型的輸出結果
# Display the model's predictions on the validation set
def test():
    """Evaluate the trained U-Net on the validation set and show results.

    Loads the checkpoint from ``args.ckp``, predicts each validation image
    (batch size 1), accumulates per-image IoU via ``get_iou``, and displays
    input/prediction pairs with matplotlib.

    Fixes: the mean IoU was divided by a hard-coded 20 instead of the real
    image count; the manual ``i`` index is replaced by ``enumerate``.
    """
    model = Unet(3, 1).to(device)  # 3-channel input, 1 output channel: only the liver class besides background
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))  # load the trained weights
    liver_dataset = LiverDataset(r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode: each prediction is drawn as it is produced

    with torch.no_grad():
        miou_total = 0
        num = len(dataloaders)  # total number of validation images (batch_size == 1)
        for i, (x, _) in enumerate(dataloaders):
            x = x.to(device)
            y = model(x)

            # Squeeze away the batch/channel dims; get_iou expects a 2-D numpy map.
            img_y = torch.squeeze(y).cpu().numpy()
            mask = get_data(i)[1]  # path of the i-th ground-truth mask
            # NOTE: get_iou binarizes img_y in place, so the plot below shows the binary map.
            miou_total += get_iou(mask, img_y)
            plt.subplot(121)
            plt.imshow(Image.open(get_data(i)[0]))
            plt.subplot(122)
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()
        # Average over the actual image count (was a hard-coded 20).
        print('Miou=%f' % (miou_total / max(num, 1)))

if __name__ =="__main__":
    x_transforms = transforms.Compose([
        transforms.ToTensor(),  # -> [0,1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # ->[-1,1]
    ])

    # mask只需要轉換爲tensor
    y_transforms = transforms.ToTensor()

    # 參數解析器,用來解析從終端讀取的命令
    parse = argparse.ArgumentParser()
    #parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train or test",default="train")
    parse.add_argument("--batch_size", type=int, default=1)
    parse.add_argument("--ckp", type=str, help="the path of model weight file")
    args = parse.parse_args()

    # train
    # train()        #測試時,就把此train()語句註釋掉

    # test()
    args.ckp = r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\weights.pth"
    test()

獲取數據代碼:

dataset.py:

import torch.utils.data as data
import PIL.Image as Image
import os


def make_dataset(root):
    """List (image_path, mask_path) pairs found in *root*.

    Each sample consists of two files, ``NNN.png`` and ``NNN_mask.png``, so
    the pair count is half the number of directory entries.
    """
    pair_count = len(os.listdir(root)) // 2
    return [
        (os.path.join(root, "%03d.png" % idx),
         os.path.join(root, "%03d_mask.png" % idx))
        for idx in range(pair_count)
    ]


class LiverDataset(data.Dataset):
    """Dataset of (CT image, liver mask) pairs stored as paired PNG files.

    Fix: the original ``__getitem__`` left ``img_x``/``img_y`` unbound
    (UnboundLocalError) whenever the corresponding transform was None; the
    raw PIL image is now returned in that case.
    """

    def __init__(self, root, transform=None, target_transform=None):
        # root: directory holding NNN.png / NNN_mask.png pairs.
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the (image, mask) pair at *index*, transformed if configured."""
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)

四,網絡結果:

網絡結果一般用以下兩種方式來評估:

1.直觀視覺效果,或者

2.量化指標

語義分割中有一個重要的指標就是miou,平均交併比,其中miou的代碼如下:

mIou.py:

import cv2
import numpy as np

class IOUMetric:
    """Accumulate a confusion matrix and derive mean-IoU statistics."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # hist[t, p] counts pixels with ground-truth class t predicted as p.
        self.hist = np.zeros((num_classes, num_classes))

    def _fast_hist(self, label_pred, label_true):
        # Only count pixels whose ground-truth label is a valid class id.
        valid = (label_true >= 0) & (label_true < self.num_classes)
        combined = self.num_classes * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(combined, minlength=self.num_classes ** 2)
        return counts.reshape(self.num_classes, self.num_classes)

    def add_batch(self, predictions, gts):
        # Fold each (prediction, ground-truth) pair into the running matrix.
        for pred, gt in zip(predictions, gts):
            self.hist += self._fast_hist(pred.flatten(), gt.flatten())

    def evaluate(self):
        """Return (pixel acc, mean class acc, per-class IoU, mean IoU, freq-weighted acc)."""
        diag = np.diag(self.hist)
        row_sum = self.hist.sum(axis=1)
        col_sum = self.hist.sum(axis=0)
        total = self.hist.sum()
        acc = diag.sum() / total
        acc_cls = np.nanmean(diag / row_sum)
        iu = diag / (row_sum + col_sum - diag)
        mean_iu = np.nanmean(iu)
        freq = row_sum / total
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc



def get_iou(mask_name, predict):
    """Compute mean IoU between a ground-truth mask file and a prediction.

    Parameters
    ----------
    mask_name : str
        Path to the grayscale ground-truth mask image.
    predict : np.ndarray
        2-D prediction map in [0, 1] (sigmoid output). It is binarized IN
        PLACE at threshold 0.5 — callers rely on this (they display the
        binarized map afterwards).

    Returns
    -------
    float
        Mean IoU over the two classes (liver + background).

    Fixes vs. the original: the per-pixel Python double loops are replaced
    by vectorized numpy operations, and the unused ``o`` pixel counter
    (computed but never read) is removed.
    """
    image_mask = cv2.imread(mask_name, 0)  # read the mask as 8-bit grayscale
    # Binarize the prediction in place: values >= 0.5 are liver, else background.
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1
    # Binarize the grayscale mask: pixels below 125 count as black/background.
    image_mask = (image_mask >= 125).astype(image_mask.dtype)
    predict = predict.astype(np.int16)

    iou_metric = IOUMetric(2)  # 2 classes: liver + background
    iou_metric.add_batch(predict, image_mask)
    _, _, _, mean_iu, _ = iou_metric.evaluate()
    print('%s:iou=%f' % (mask_name, mean_iu))
    return mean_iu

 

五,運行結果:

驗證集的miou:

六,代碼和數據集獲取:

代碼和數據集都在以下鏈接了:

the Liver dataset: link:https://pan.baidu.com/s/1FljGCVzu7HPYpwAKvSVN4Q keyword:5l88

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章