GitHub Reproduction: D-LinkNet (validation code added, results are decent, definitely take a look)

链接:https://github.com/zlkanata/DeepGlobe-Road-Extraction-Challenge
One more thing: the repo also includes a UNet, so you can run both and compare which performs better.
This project was originally built for road segmentation, but it performs well beyond roads too. My reproduction differs slightly from the original: to make a fair comparison with other networks, I skipped the original data-augmentation step and trained directly on the raw images. I also added validation code, including a progress bar (you will like this one; it is genuinely not an answer you can find easily, see the short sketch below), validation loss, IoU, and so on. The augmentation code is still there, so try it yourself if you find it useful.
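
As a taste of that progress bar, here is a minimal sketch of the tqdm pattern used in the training and validation loops further down (the range is just a stand-in for a DataLoader; ncols controls the bar width):

from tqdm import tqdm

loader = range(100)  # stand-in for a DataLoader
for batch in tqdm(loader, ncols=20, total=len(loader)):
    pass  # one training or validation step per batch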
Data link: https://pan.baidu.com/s/1GHrf4YWaCCSNq_dYFpIfAA
Extraction code: 2snq
If you find this helpful, please give it a like (the data is all prepared for you, so don't just freeload). Love you, mwah!
Trained for 100 epochs on CUDA 8.0, cuDNN 6.1, PyTorch 0.4.1 (no need to rush to change your environment; newer versions should train this fine).
Results first:
The IoU is 0.60824211 with no data augmentation, just straightforward training on the raw data. Not bad; tune it yourself and it will get much better.

('acc: ', 0.9790308253835388)
('acc_cls: ', 0.8492737379119986)
('iou: ', array([0.97832516, 0.60824211]))
('miou: ', 0.7932836363940091)
('fwavacc: ', 0.9612672801739732)
('class_accuracy: ', 0.8141106006333385)
('class_recall: ', 0.7063404781913989)
('accuracy: ', 0.9790308253835388)
('f1_score: ', 0.7564061467818185)

Some cropped validation images:
Original images
Results for the corresponding validation images
Data directory structure
Training and validation data are both laid out as below. Note that an image and its label share the same number, with the image ending in _sat and the label in _mask (see the sketch below).
Directory layout:
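
A made-up example of the layout (the IDs here are hypothetical; only the matching numbers and the _sat/_mask suffixes matter):

road512/
    train/
        104_sat.png
        104_mask.png
        105_sat.png
        105_mask.png
    val/
        201_sat.png
        201_mask.png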
The code follows; I will post every file I changed in full.
Training code: train.py

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import warnings
from tqdm import tqdm
import numpy as np
from time import time
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder
import torch.nn.functional as F
# from test import TTAFrame



def iou(img_true, img_pred):
    img_pred = (img_pred > 0).float()
    i = (img_true * img_pred).sum()
    u = (img_true + img_pred).sum()
    return i / u if u != 0 else u

def iou_metric(imgs_pred, imgs_true):
    num_images = len(imgs_true)
    scores = np.zeros(num_images)
    for i in range(num_images):
        if imgs_true[i].sum() == imgs_pred[i].sum() == 0:
            scores[i] = 1
        else:
            # iou_thresholds = np.array([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
            # scores[i] = (iou_thresholds <= iou(imgs_true[i], imgs_pred[i])).mean()
            scores[i] = iou(imgs_true[i], imgs_pred[i])
    return scores.mean()

def get_one_hot(label, N):  # one-hot encode an integer label tensor; the output gains a trailing dim of size N
    size = list(label.size())
    label = label.view(-1)
    ones = torch.sparse.torch.eye(N)
    ones = ones.index_select(0, label)
    size.append(N)
    return ones.view(*size)

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
 
    def forward(self, input, target):
        N = target.size(0)
        smooth = 1
 
        input_flat = input.view(N, -1)
        target_flat = target.view(N, -1)
 
        intersection = input_flat * target_flat
 
        loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
        loss = 1 - loss.sum() / N
 
        return loss

class MulticlassDiceLoss(nn.Module):
    """
    requires one hot encoded target. Applies DiceLoss on each class iteratively.
    requires input.shape[0:1] and target.shape[0:1] to be (N, C) where N is
      batch size and C is number of classes
    """
    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()
 
    def forward(self, input, target, weights=None):
 
        C = target.shape[1]
 
        # if weights is None:
        #   weights = torch.ones(C) #uniform weights for all classes
 
        dice = DiceLoss()
        totalLoss = 0
 
        for i in range(C):
            diceLoss = dice(input[:,i], target[:,i])
            if weights is not None:
                diceLoss *= weights[i]
            totalLoss += diceLoss
 
        return totalLoss
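
# Usage sketch (my addition, not from the original repo): pairing
# MulticlassDiceLoss with get_one_hot above for a 2-class problem.
#   label: N x H x W LongTensor of {0, 1}; pred: N x 2 x H x W probabilities
#   target = get_one_hot(label, 2).permute(0, 3, 1, 2)  # N x H x W x 2 -> N x 2 x H x W
#   loss = MulticlassDiceLoss()(pred, target)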

                    
class SoftIoULoss(nn.Module):
    def __init__(self, n_classes):
        super(SoftIoULoss, self).__init__()
        self.n_classes = n_classes

    @staticmethod
    def to_one_hot(tensor, n_classes):
        n, h, w = tensor.size()
        one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1)
        return one_hot

    def forward(self, input, target):
        # logit => N x Classes x H x W
        # target => N x H x W

        N = len(input)

        pred = F.softmax(input, dim=1)
        target_onehot = self.to_one_hot(target, self.n_classes)

        # Numerator Product
        inter = pred * target_onehot
        # Sum over all pixels N x C x H x W => N x C
        inter = inter.view(N, self.n_classes, -1).sum(2)

        # Denominator
        union = pred + target_onehot - (pred * target_onehot)
        # Sum over all pixels N x C x H x W => N x C
        union = union.view(N, self.n_classes, -1).sum(2)

        loss = inter / (union + 1e-16)

        # Return average loss over classes and batch
        return -loss.mean()

def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor, SMOOTH = 1e-6):
    # You can comment out this line if you are passing tensors of equal shape
    # But if you are passing output from UNet or something it will most probably
    # be with the BATCH x 1 x H x W shape
    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    
    intersection = (outputs & labels).float().sum((1, 2))  # will be zero if truth=0 or prediction=0
    union = (outputs | labels).float().sum((1, 2))         # will be zero if both are 0
    
    iou = (intersection + SMOOTH) / (union + SMOOTH)  # smooth the division to avoid 0/0
    
    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # equivalent to comparing against thresholds
    
    return thresholded.mean()  # average across the batch

# Numpy version
# Well, it's the same function, so I'm going to omit the comments

def iou_numpy(outputs: np.ndarray, labels: np.ndarray, SMOOTH=1e-6):  # SMOOTH was undefined here originally
    outputs = outputs.squeeze(1)
    
    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))
    
    iou = (intersection + SMOOTH) / (union + SMOOTH)
    
    thresholded = np.ceil(np.clip(20 * (iou - 0.5), 0, 10)) / 10
    
    return thresholded  # Or thresholded.mean()


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    SHAPE = (512,512)
    train_root = 'D:/complete_project/Dinknet/road512/train/'
    imagelist = filter(lambda x: x.find('sat')!=-1, os.listdir(train_root))
    trainlist = map(lambda x: x[:-8], imagelist)
    trainlist = list(trainlist)

    val_root = 'D:/complete_project/Dinknet/road512/val/'
    imagelist = filter(lambda x: x.find('sat')!=-1, os.listdir(val_root))
    vallist = map(lambda x: x[:-8], imagelist)
    vallist = list(vallist)

    NAME = 'dinknet3'
    BATCHSIZE_PER_CARD = 8  # batch size of 8 per GPU

    solver = MyFrame(DinkNet34, dice_bce_loss, 1e-3)
    # solver.load('./weights/test.th')

    train_batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
    val_batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD * 2

    train_dataset = ImageFolder(trainlist, train_root)
    val_dataset = ImageFolder(vallist, val_root)

    data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size = train_batchsize,
        shuffle=True,
        num_workers=0)
    
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size = val_batchsize,
        shuffle=True,
        num_workers=0)


    mylog = open('logs/'+NAME+'.log','w')
    tic = time()
    device = torch.device('cuda:0')
    no_optim = 0
    total_epoch = 100
    train_epoch_best_loss = 100.
    
    test_loss = 0
    # criteon = nn.CrossEntropyLoss().to(device)
    criteon = DiceLoss()
    iou_criteon = SoftIoULoss(2)
    # scheduler = solver.lr_strategy()  # an LR schedule I added but have not used yet; only the original policy is active. Try it if interested.
    for epoch in range(1, total_epoch + 1):
        print('---------- Epoch:'+str(epoch)+ ' ----------')
        # scheduler.step()  # goes with the LR schedule above; enable both together
        # print('lr={:.6f}'.format(scheduler.get_lr()[0]))  # prints the scheduled LR; enable together with the above
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        print('Train:')
        for img, mask in tqdm(data_loader_iter,ncols=20,total=len(data_loader_iter)):
            solver.set_input(img, mask)
            train_loss = solver.optimize()
            train_epoch_loss += train_loss
        train_epoch_loss /= len(data_loader_iter)
        
        val_data_loader_num = iter(val_data_loader)
        test_epoch_loss = 0
        test_mean_iou = 0
        val_pre_list = []
        val_mask_list = []
        print('Validation:')
        for val_img, val_mask in tqdm(val_data_loader_num,ncols=20,total=len(val_data_loader_num)):
            val_img, val_mask = val_img.to(device), val_mask.to(device)
            val_mask[val_mask > 0] = 1  # binarize the mask; np.where does not work on a CUDA tensor
            val_mask = val_mask.squeeze(0)
            predict = solver.test_one_img(val_img)
            predict_temp = torch.from_numpy(predict).unsqueeze(0)
            predict_use = V(predict_temp.type(torch.FloatTensor),volatile=True)
            val_use = V(val_mask.type(torch.FloatTensor),volatile=True)
            test_epoch_loss += criteon.forward(predict_use,val_use)
            predict_use = predict_use.squeeze(0)
            predict_use = predict_use.unsqueeze(1)
            predict_use[predict_use >= 0.5] = 1
            predict_use[predict_use < 0.5] = 0
            predict_use = predict_use.type(torch.LongTensor)
            val_use = val_use.squeeze(1).type(torch.LongTensor)
            test_mean_iou += iou_pytorch(predict_use, val_use)
        

        batch_iou = test_mean_iou / len(val_data_loader_num)
        val_loss = test_epoch_loss / len(val_data_loader_num)
        
        mylog.write('********************' + '\n')
        mylog.write('--epoch:' + str(epoch) + '  --time:' + str(int(time()-tic)) + '  --train_loss:' + str(train_epoch_loss) + ' --val_loss:' + str(val_loss.item()) + ' --val_iou:' + str(batch_iou.item()) + '\n')
        print('--epoch:', epoch, '  --time:', int(time()-tic), '  --train_loss:', train_epoch_loss, ' --val_loss:', val_loss.item(), ' --val_iou:', batch_iou.item())
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('weights/'+NAME+'.th')
        if no_optim > 6:
            mylog.write('early stop at %d epoch\n' % epoch)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > 3:
            if solver.old_lr < 5e-7:
                break
            solver.load('weights/'+NAME+'.th')
            solver.update_lr(5.0, factor = True, mylog = mylog)
        mylog.flush()
        
    mylog.write('Finish!\n')
    print('Finish!')
    mylog.close()

framework.py (not much changed)

import torch
import torch.nn as nn
from torch.autograd import Variable as V
from torch.optim import lr_scheduler

import cv2
import numpy as np

class MyFrame():
    def __init__(self, net, loss, lr=2e-4, evalmode = False):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=lr)
        #self.optimizer = torch.optim.RMSprop(params=self.net.parameters(), lr=lr)

        self.loss = loss()
        self.old_lr = lr
        if evalmode:
            for i in self.net.modules():
                if isinstance(i, nn.BatchNorm2d):
                    i.eval()
        
    def set_input(self, img_batch, mask_batch=None, img_id=None):
        self.img = img_batch
        self.mask = mask_batch
        self.img_id = img_id
        
    def test_one_img(self, img):  # part of the original body is commented out below
        pred = self.net.forward(img)
        
        # pred[pred>0.5] = 1
        # pred[pred<=0.5] = 0

        # mask = pred.squeeze().cpu().data.numpy()
        mask = pred.squeeze().cpu().data.numpy()
        return mask
    
    def test_batch(self):
        self.forward(volatile=True)
        mask =  self.net.forward(self.img).cpu().data.numpy().squeeze(1)
        mask[mask>0.5] = 1
        mask[mask<=0.5] = 0
        
        return mask, self.img_id
    
    def test_one_img_from_path(self, path):
        img = cv2.imread(path)
        img = np.array(img, np.float32)/255.0 * 3.2 - 1.6
        img = V(torch.Tensor(img).cuda())
        
        mask = self.net.forward(img).squeeze().cpu().data.numpy()#.squeeze(1)
        mask[mask>0.5] = 1
        mask[mask<=0.5] = 0  
        return mask
        
    def forward(self, volatile=False):
        self.img = V(self.img.cuda(), volatile=volatile)
        if self.mask is not None:
            self.mask = V(self.mask.cuda(), volatile=volatile)
        
    def optimize(self):
        self.forward()
        self.optimizer.zero_grad()
        pred = self.net.forward(self.img)
        loss = self.loss(self.mask, pred)
        loss.backward()
        self.optimizer.step()
        return loss.item()  # was loss.data[0], which fails on PyTorch >= 0.4; .item() returns a Python float
        
    def save(self, path):
        torch.save(self.net.state_dict(), path)
        
    def load(self, path):
        self.net.load_state_dict(torch.load(path))
    
    def update_lr(self, new_lr, mylog, factor=False):
        if factor:
            new_lr = self.old_lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        mylog.write('update learning rate: %f -> %f\n' % (self.old_lr, new_lr))
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr

    def lr_strategy(self):  # newly added
        # scheduler = lr_scheduler.StepLR(self.optimizer, step_size=30, gamma=0.1)
        # scheduler = lr_scheduler.MultiStepLR(self.optimizer, [30, 80], 0.1)
        scheduler = lr_scheduler.ExponentialLR(self.optimizer, gamma=0.9)
        return scheduler
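
Here is a sketch of how that scheduler could be wired in (my assumption of the intended usage; the commented-out scheduler lines in train.py above follow the same pattern):

scheduler = solver.lr_strategy()
for epoch in range(1, total_epoch + 1):
    # ... run the training and validation batches for this epoch ...
    scheduler.step()  # one exponential decay step per epoch (gamma = 0.9)
    print('lr={:.6f}'.format(scheduler.get_lr()[0]))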

data.py: the data entry point. The original input pipeline applied the author's augmentations; here I added a function that loads only the raw data.

"""
Based on https://github.com/asanakoy/kaggle_carvana_segmentation
"""
import torch
import torch.utils.data as data
from torch.autograd import Variable as V
import torchvision.transforms as transforms
import torchvision

import cv2
import numpy as np
import os

def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)
        hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1]+1)
        hue_shift = np.uint8(hue_shift)
        h += hue_shift
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)
        image = cv2.merge((h, s, v))
        #image = cv2.merge((s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image

def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0), 
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    if np.random.random() < u:
        height, width, channel = image.shape

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        cc = np.math.cos(angle / 180 * np.math.pi) * sx
        ss = np.math.sin(angle / 180 * np.math.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode,
                                    borderValue=(
                                        0, 0,
                                        0,))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode,
                                   borderValue=(
                                       0, 0,
                                       0,))

    return image, mask

def randomHorizontalFlip(image, mask, u=0.5):
    if np.random.random() < u:
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)

    return image, mask

def randomVerticleFlip(image, mask, u=0.5):
    if np.random.random() < u:
        image = cv2.flip(image, 0)
        mask = cv2.flip(mask, 0)

    return image, mask

def randomRotate90(image, mask, u=0.5):
    if np.random.random() < u:
        image=np.rot90(image)
        mask=np.rot90(mask)

    return image, mask

def default_loader(id, root): # the author's original loader with heavy augmentation; quite good, try it on its own
    img = cv2.imread(os.path.join(root,'{}_sat.png').format(id))
    mask = cv2.imread(os.path.join(root+'{}_mask.png').format(id), cv2.IMREAD_GRAYSCALE)
    
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)
    
    mask = np.expand_dims(mask, axis=2)
    img = np.array(img, np.float32).transpose(2,0,1)/255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32).transpose(2,0,1)/255.0
    mask[mask>=0.5] = 1
    mask[mask<=0.5] = 0
    #mask = abs(mask-1)
    return img, mask

def own_loader(id, root):  # loads only the raw data; this is the loader actually used here
    img = cv2.imread(os.path.join(root,'{}_sat.png').format(id))
    mask = cv2.imread(os.path.join(root+'{}_mask.png').format(id), cv2.IMREAD_GRAYSCALE)
    mask = np.expand_dims(mask, axis=2)
    # img = np.array(img, np.float32).transpose(2,0,1)/255.0 * 3.2 - 1.6
    # mask = np.array(mask, np.float32).transpose(2,0,1)/255.0
    mask[mask>=0.5] = 1
    mask[mask<=0.5] = 0
    img = np.array(img, np.float32).transpose(2,0,1)
    mask = np.array(mask, np.float32).transpose(2,0,1)
    return img, mask

class ImageFolder(data.Dataset):

    def __init__(self, trainlist, root):
        self.ids = trainlist
        # self.loader = default_loader  # the original augmented loader
        self.loader = own_loader  # my raw-data loader
        self.root = root
        # self.trans = transforms.Compose([transforms.ToTensor()])
        # self.trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

    def __getitem__(self, index):
        id = list(self.ids)[index]
        img, mask = self.loader(id, self.root)
        # img = np.transpose(img, (1,2,0))
        # img = self.trans(img)
        img = torch.Tensor(img)
        mask = torch.Tensor(mask)
        return img, mask

    def __len__(self):
        return len(list(self.ids))

test.py: inference code

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle

from time import time

from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool


BATCHSIZE_PER_CARD = 8

class TTAFrame():
    def __init__(self, net):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        
    def test_one_img_from_path(self, path, evalmode = True):
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)
        elif batchsize >= 4:
            return self.test_one_img_from_path_2(path)
        elif batchsize >= 2:
            return self.test_one_img_from_path_4(path)

    def test_one_img_from_path_8(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.array(img1)[:,:,::-1]
        img4 = np.array(img2)[:,:,::-1]
        
        img1 = img1.transpose(0,3,1,2)
        img2 = img2.transpose(0,3,1,2)
        img3 = img3.transpose(0,3,1,2)
        img4 = img4.transpose(0,3,1,2)
        
        img1 = V(torch.Tensor(np.array(img1, np.float32)/255.0 * 3.2 -1.6).cuda())
        img2 = V(torch.Tensor(np.array(img2, np.float32)/255.0 * 3.2 -1.6).cuda())
        img3 = V(torch.Tensor(np.array(img3, np.float32)/255.0 * 3.2 -1.6).cuda())
        img4 = V(torch.Tensor(np.array(img4, np.float32)/255.0 * 3.2 -1.6).cuda())
        
        maska = self.net.forward(img1).squeeze().cpu().data.numpy()
        maskb = self.net.forward(img2).squeeze().cpu().data.numpy()
        maskc = self.net.forward(img3).squeeze().cpu().data.numpy()
        maskd = self.net.forward(img4).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,::-1] + maskc[:,:,::-1] + maskd[:,::-1,::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1,::-1]
        
        return mask2

    def test_one_img_from_path_4(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.array(img1)[:,:,::-1]
        img4 = np.array(img2)[:,:,::-1]
        
        img1 = img1.transpose(0,3,1,2)
        img2 = img2.transpose(0,3,1,2)
        img3 = img3.transpose(0,3,1,2)
        img4 = img4.transpose(0,3,1,2)
        
        img1 = V(torch.Tensor(np.array(img1, np.float32)/255.0 * 3.2 -1.6).cuda())
        img2 = V(torch.Tensor(np.array(img2, np.float32)/255.0 * 3.2 -1.6).cuda())
        img3 = V(torch.Tensor(np.array(img3, np.float32)/255.0 * 3.2 -1.6).cuda())
        img4 = V(torch.Tensor(np.array(img4, np.float32)/255.0 * 3.2 -1.6).cuda())
        
        maska = self.net.forward(img1).squeeze().cpu().data.numpy()
        maskb = self.net.forward(img2).squeeze().cpu().data.numpy()
        maskc = self.net.forward(img3).squeeze().cpu().data.numpy()
        maskd = self.net.forward(img4).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,::-1] + maskc[:,:,::-1] + maskd[:,::-1,::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1,::-1]
        
        return mask2
    
    def test_one_img_from_path_2(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.concatenate([img1,img2])
        img4 = np.array(img3)[:,:,::-1]
        img5 = img3.transpose(0,3,1,2)
        img5 = np.array(img5, np.float32)/255.0 * 3.2 -1.6
        img5 = V(torch.Tensor(img5).cuda())
        img6 = img4.transpose(0,3,1,2)
        img6 = np.array(img6, np.float32)/255.0 * 3.2 -1.6
        img6 = V(torch.Tensor(img6).cuda())
        
        maska = self.net.forward(img5).squeeze().cpu().data.numpy()#.squeeze(1)
        maskb = self.net.forward(img6).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,:,::-1]
        mask2 = mask1[:2] + mask1[2:,::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1,::-1]
        
        return mask3
    
    def test_one_img_from_path_1(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.concatenate([img1,img2])
        img4 = np.array(img3)[:,:,::-1]
        img5 = np.concatenate([img3,img4]).transpose(0,3,1,2)
        img5 = np.array(img5, np.float32)/255.0 * 3.2 -1.6
        img5 = V(torch.Tensor(img5).cuda())
        
        mask = self.net.forward(img5).squeeze().cpu().data.numpy()#.squeeze(1)
        mask1 = mask[:4] + mask[4:,:,::-1]
        mask2 = mask1[:2] + mask1[2:,::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1,::-1]
        
        return mask3

    def load(self, path):
        self.net.load_state_dict(torch.load(path))
        
#source = 'dataset/test/'
source = 'D:/complete_project/Dinknet/road512/test/'
val = os.listdir(source)
solver = TTAFrame(DinkNet34)
# solver = TTAFrame(LinkNet34)
solver.load('weights/dinknet.th')
tic = time()
target = 'submits/log01_dink34/'
# os.mkdir(target)
for i,name in enumerate(val):
    if i%10 == 0:
        print(i/10, '    ','%.2f'%(time()-tic))
    mask = solver.test_one_img_from_path_8(source+name)  # Curious about this? Don't panic: the author's intent seems to be a kind of normalization. The 8-fold TTA sums eight predictions, so the 4.0 threshold below corresponds to an average output of 0.5. You could write something simpler yourself (see the sketch after this script); I haven't had time to try, so use this for now. The approach is worth exploring.
    mask[mask>4.0] = 255
    mask[mask<=4.0] = 0
    # mask = np.concatenate([mask[:,:,None],mask[:,:,None],mask[:,:,None]],axis=2)
    cv2.imwrite(target+name[:-7]+'mask.png',mask.astype(np.uint8))
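
For reference, a minimal non-TTA sketch of the simpler alternative the comment above alludes to (my own untested version; it assumes the same /255.0 * 3.2 - 1.6 normalization used by the TTA paths, and thresholds a single prediction at 0.5):

def predict_single(net, path):
    img = cv2.imread(path)
    img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
    img = V(torch.Tensor(img[None]).cuda())  # add a batch dim: 1 x 3 x H x W
    mask = net.forward(img).squeeze().cpu().data.numpy()
    mask[mask > 0.5] = 255   # a single prediction, so threshold at 0.5 instead of 4.0
    mask[mask <= 0.5] = 0
    return mask.astype(np.uint8)

# e.g. mask = predict_single(solver.net, source + name)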

eval.py: accuracy evaluation

# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
from sklearn.metrics import confusion_matrix

class IOUMetric:
    """
    Class to calculate mean-iou using fast_hist method
    """
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))
    def _fast_hist(self, label_pred, label_true):
        mask = (label_true >= 0) & (label_true < self.num_classes)        
        hist = np.bincount(
            self.num_classes * label_true[mask].astype(int) +
            label_pred[mask], minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
        return hist

    def evaluate(self, predictions, gts):
        for lp, lt in zip(predictions, gts):
            assert len(lp.flatten()) == len(lt.flatten())
            self.hist += self._fast_hist(lp.flatten(), lt.flatten())    
        # miou
        iou = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
        miou = np.nanmean(iou) 
        # mean acc
        acc = np.diag(self.hist).sum() / self.hist.sum()
        acc_cls = np.nanmean(np.diag(self.hist) / self.hist.sum(axis=1))
        freq = self.hist.sum(axis=1) / self.hist.sum()
        fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
        return acc, acc_cls, iou, miou, fwavacc
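
# Toy sanity check (my addition, not in the original script): a single 2x2 image.
#   m = IOUMetric(2)
#   _, _, iou, _, _ = m.evaluate([np.array([[0, 1], [1, 1]])],   # prediction
#                                [np.array([[0, 1], [0, 1]])])   # ground truth
#   # class 1: TP=2, FP=1, FN=0  ->  iou[1] = 2 / (2 + 1 + 0) = 2/3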


if __name__ == '__main__':
    label_path = 'D:/complete_project/Dinknet/submits/label/'
    predict_path = 'D:/complete_project/Dinknet/submits/log01_dink34/'
    pres = os.listdir(predict_path)
    labels = []
    predicts = []
    for im in pres:
        if im[-4:] == '.png':
            label_name = im.split('.')[0] + '.png'
            lab_path = os.path.join(label_path, label_name)
            pre_path = os.path.join(predict_path, im)
            label = cv2.imread(lab_path,0)
            pre = cv2.imread(pre_path,0)
            label[label>0] = 1
            pre[pre>0] = 1
            labels.append(label)
            predicts.append(pre)
    el = IOUMetric(2)
    acc, acc_cls, iou, miou, fwavacc = el.evaluate(predicts, labels)
    print('acc: ',acc)
    print('acc_cls: ',acc_cls)
    print('iou: ',iou)
    print('miou: ',miou)
    print('fwavacc: ',fwavacc)

    pres = os.listdir(predict_path)
    init = np.zeros((2,2))
    for im in pres:
        lb_path = os.path.join(label_path, im)
        pre_path = os.path.join(predict_path, im)
        lb = cv2.imread(lb_path,0)
        pre = cv2.imread(pre_path,0)
        lb[lb>0] = 1
        pre[pre>0] = 1
        lb = lb.flatten()
        pre = pre.flatten()
        confuse = confusion_matrix(lb, pre)
        init += confuse

    precision = init[1][1]/(init[0][1] + init[1][1]) 
    recall = init[1][1]/(init[1][0] + init[1][1])
    accuracy = (init[0][0] + init[1][1])/init.sum()
    f1_score = 2*precision*recall/(precision + recall)
    print('class_accuracy: ', precision)
    print('class_recall: ', recall)
    print('accuracy: ', accuracy)
    print('f1_score: ', f1_score)

Below is some simple helper code you might need.
resize.py: image resizing

import os, cv2
label_path = '/complete_project/Dinknet/submits/log01_dink34/'
out = 'D:/complete_project/Dinknet/submits/result/'
labels = os.listdir(label_path)
for label in labels:
    name = label.split('.')[0]
    lb_path = os.path.join(label_path, label)
    lb = cv2.imread(lb_path,0)
    lb = cv2.resize(lb, (500, 500), interpolation = cv2.INTER_NEAREST)
    # lb[lb>0] = 255
    cv2.imwrite(os.path.join(out, name[:-5]+'.png'),lb, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])  # mind the quality of the saved image; don't lose information

OK, that's everything. The rest of the unchanged code can be downloaded from the link above. Note that this is only a code cleanup; the training hyperparameters were not tuned or optimized, so the rest is up to you. Good luck! And if you made it this far, don't forget to move those adorable little hands of yours and give it a like; a bookmark is grudgingly acceptable too.
