利用U-Net網絡對遙感影像道路信息分割提取

一、論文閱讀

原始論文是《U-Net: Convolutional Networks for Biomedical Image Segmentation》,地址:https://arxiv.org/abs/1505.04597。其網絡結構主要是以“U”型編碼器-解碼器構成了下采樣-上採樣兩部分功能結構。下采樣採用典型的卷積網絡架構,就採樣結果而言,每層的Max-Pooling採樣減小了圖像尺寸,但是成倍增加了channels,具體每層卷積操作可以看代碼或者詳讀論文。上採樣過程中對下采樣的結果進行Conv-Transpose反捲積操作,逐層恢復圖像尺寸,網絡結構如圖1.1:

圖1.1 U-Net網絡架構

二、代碼實現

代碼分成了三個py文件,分別爲數據預處理模塊dataset.py,網絡模型實現模塊unet.py以及main.py。

# dataset.py
from torch.utils.data import Dataset
import PIL.Image as Image
import os
def make_dataset(root):
    """Collect (image, mask) path pairs from directory `root`.

    Samples are expected to be named ``%03d.png`` with matching
    ``%03d_mask.png`` labels, numbered consecutively from 0.

    The original version derived the sample count from
    ``len(os.listdir(root)) // 2``, which miscounts as soon as the
    directory contains any extra file; instead, enumerate indices until a
    pair is missing, so only pairs that actually exist are returned.
    """
    imgs = []
    i = 0
    while True:
        img = os.path.join(root, "%03d.png" % i)
        mask = os.path.join(root, "%03d_mask.png" % i)
        if not (os.path.exists(img) and os.path.exists(mask)):
            break
        imgs.append((img, mask))
        i += 1
    return imgs
class LiverDataset(Dataset):
    """Paired image/mask dataset feeding a segmentation DataLoader.

    Each item is an ``(image, mask)`` tuple; either half can be run
    through its own optional transform before being returned.
    """

    def __init__(self, root, transform=None, target_transform=None):
        # Resolve all (image, mask) path pairs up front.
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index]
        image = Image.open(img_path)
        mask = Image.open(mask_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)

dataset.py中有兩個功能函數,make_dataset模塊是將樣本以及樣本標籤導入。LiverDataset模塊是爲了做DataLoader而準備。

# unet.py
import torch
from torch import nn
class DoubleConv(nn.Module):
    """Two (Conv3x3 -> BatchNorm -> ReLU) stages; padding=1 keeps H and W."""

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # Keep the attribute name `conv` so checkpoint keys are unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.conv(input)
class Unet(nn.Module):
    """U-Net encoder/decoder for semantic segmentation.

    Four max-pool encoder stages double the channel count while halving
    the spatial size; four transposed-conv decoder stages undo that,
    concatenating the matching encoder feature map (skip connection)
    before each DoubleConv. A final 1x1 conv maps to `out_ch` classes.
    Attribute names are kept identical so saved state_dicts still load.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # --- encoder ---
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        # --- bottleneck ---
        self.conv5 = DoubleConv(512, 1024)
        # --- decoder (each DoubleConv sees upsampled + skip channels) ---
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 projection to the requested number of output channels.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Encoder path; keep each stage's output for its skip connection.
        e1 = self.conv1(x)
        e2 = self.conv2(self.pool1(e1))
        e3 = self.conv3(self.pool2(e2))
        e4 = self.conv4(self.pool3(e3))
        bottleneck = self.conv5(self.pool4(e4))
        # Decoder path: upsample, concatenate skip on the channel axis, fuse.
        d4 = self.conv6(torch.cat([self.up6(bottleneck), e4], dim=1))
        d3 = self.conv7(torch.cat([self.up7(d4), e3], dim=1))
        d2 = self.conv8(torch.cat([self.up8(d3), e2], dim=1))
        d1 = self.conv9(torch.cat([self.up9(d2), e1], dim=1))
        return self.conv10(d1)

 

# main.py
import torch
import argparse
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import cv2
import os
from tensorboardX import SummaryWriter
# Select CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input images: convert to tensor, then normalize each RGB channel to [-1, 1].
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# The mask only needs conversion to a tensor.
y_transforms = transforms.ToTensor()
def train_model(model, criterion, optimizer, dataload, num_epochs=3):
    """Train `model` for `num_epochs` epochs over `dataload`.

    Logs the per-batch loss to TensorBoard (directory ``model_record``),
    prints progress, saves the final weights to ``weights_<last>.pth``
    and returns the trained model.
    """
    writer = SummaryWriter(r'model_record')
    # Actual batches per epoch. The original hard-coded 200 in the
    # global_step formula, which scrambles the loss curve for any other
    # dataset/batch-size combination.
    steps_per_epoch = len(dataload)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward / backward / update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar('train loss', loss.item(),
                              global_step=step + epoch * steps_per_epoch)
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        # max(step, 1) guards against a zero-division on an empty loader.
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss / max(step, 1)))
    writer.close()  # flush pending TensorBoard events
    # Save once after training; num_epochs - 1 is the last epoch index
    # (same filename as before, without relying on the leaked loop variable).
    torch.save(model.state_dict(), 'weights_%d.pth' % (num_epochs - 1))
    return model
# Train the U-Net on the road training split.
def train():
    """Build the model/loss/optimizer, load the training set, and train."""
    model = Unet(3, 1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    dataset = LiverDataset(r'data_road\train',
                           transform=x_transforms,
                           target_transform=y_transforms)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, loader)
# Show the model's predictions on the validation split.
def test():
    """Run inference on the validation set, save each binarized prediction
    as ``%03d_predict.png``, print per-image IoU and the mean IoU."""
    model = Unet(3, 1)
    # Load weights onto the CPU: inference below never moves the model or
    # inputs to a GPU, and the original `storage.cuda(0)` mapping crashes
    # outright on machines without CUDA.
    model.load_state_dict(torch.load(r'weights_5.pth', map_location='cpu'))
    liver_dataset = LiverDataset(r'data_road\val1', transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    with torch.no_grad():
        total_iou = 0.0
        record = 0
        for x, target in dataloaders:
            y = model(x)
            img_y = torch.squeeze(y).numpy()
            img_target = torch.squeeze(target).numpy()
            # Binarize the prediction at the empirically chosen threshold.
            # NOTE(review): 0.3 is applied to raw logits (the model was
            # trained with BCEWithLogitsLoss) — confirm whether a sigmoid
            # should be applied first.
            img_y[img_y > 0.3] = 255
            img_y[img_y <= 0.3] = 0
            # Vectorized IoU over the full image. The original double loop
            # hard-coded a 512x512 grid (wrong for any other size) and was
            # O(H*W) in Python; it also named the intersection count `all`,
            # shadowing the builtin.
            pred = img_y == 255
            truth = img_target == 1
            intersection = float(np.logical_and(pred, truth).sum())
            union = float(np.logical_or(pred, truth).sum())
            # Empty union (blank prediction and blank truth) -> define IoU as 0
            # instead of dividing by zero.
            iou = intersection / union if union else 0.0
            total_iou += iou
            pathname = "%03d_predict.png" % record
            cv2.imwrite(os.path.join(r'data_road/result1', pathname), img_y)
            record += 1
            print(iou)
        # Mean over the actual number of images (originally hard-coded /20.0).
        if record:
            print(total_iou / record)
if __name__ == '__main__':
    # Run training and then evaluation; comment out either call to run a
    # single phase.
    train()
    test()

main.py函數比較雜,其中我將訓練和測試函數都寫在了一起,在train時單獨運行train()將test()屏蔽即可。採用的Loss函數是nn.BCEWithLogitsLoss(),這裏有興趣可以將其變換爲其他的loss看看結果。值得注意的一點是這裏有一個閾值0.3,對應的代碼是img_y[img_y > 0.3] = 255和img_y[img_y <= 0.3] = 0,這裏需要針對每個不同的情況自己去定義合適的分割閾值。optimizer選取的是Adam。

數據集採用的是Massachusetts road,數據地址爲:https://www.cs.toronto.edu/~vmnih/data/,這裏可以用簡單爬蟲批量下載,若有需求,可以讓我在下面評論貼出該數據集的網盤地址。還有一個細節(坑)就是,U-Net結構經過四次下采樣,要求輸入影像的邊長是16的整數倍(例如512或1024)。所以這裏需要對下載的數據集(images和labels)進行批量重採樣,採樣用最鄰近和雙線性內插均可,沒有太大影響;受顯存限制,我將數據集裁剪爲了512*512。最後貼出我的訓練結果。

三、結果討論

首先,我的迭代次數不太多,也沒有采用動態學習率策略,並且massachusetts road數據集也有很多坑(需要很多預處理,去除損壞樣本,誰用誰知道),所以最後的分割效果一般,僅僅是跑通網絡。結果見圖3.1,3.2和3.3。

3.1 Training Loss
3.2 Mean IoU
3.3 從左到右分別是原始圖像-真值-提取結果

 討論:本文簡要地用U-Net網絡跑了一下遙感影像道路信息分割提取這個方面的研究,效果達到預期但是沒有想象突出,原因有以下兩點,1、原始數據集Massachusetts roads樣本有部分有較大的偏差,真值存在錯誤致使訓練錯誤。2、Loss設計不合理,具體可以見有關遙感領域分割信息提取Loss設計相關論文,本人也在學習階段。

歡迎大家留言討論。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章