Tianchi competition code

deeplabv3plus
Image processing code

import os
from PIL import Image
from configs import cfg
import numpy as np
import cv2
import random


# generate training patches from the big remote-sensing image
def gendata_pro(unit_sizes=[2048, 1024, 512], begin_ind=0, randis=1, stride_rato=0.5, balance=0, mxnm=10000):

    save_img_dir = os.path.join(cfg.DATA_DIR, 'image')
    save_mask_dir = os.path.join(cfg.DATA_DIR, 'mask')

    Image.MAX_IMAGE_PIXELS = 5000000000  # lift PIL's decompression-bomb limit for the huge image
    img = Image.open(cfg.DATA_DIR + 'jingwei_round1_train_20190619/image_2.png')
    img = np.asarray(img)
    print(img.shape)

    anno_map = Image.open(cfg.DATA_DIR + 'jingwei_round1_train_20190619/image_2_label.png')
    anno_map = np.asarray(anno_map)
    print(anno_map.shape)
    length, width = img.shape[0], img.shape[1]

    for unit_size in unit_sizes:
        nullthresh = unit_size * unit_size * 0.7  # patch skipped if more than 70% of pixels are empty
        maxthresh = unit_size * unit_size * 0.3   # a class "dominates" a patch if it covers more than 30%
        count_cls = np.zeros(4)
        ind = 0
        if not os.path.exists(save_img_dir + str(unit_size)):
            os.makedirs(save_img_dir + str(unit_size))
        if not os.path.exists(save_mask_dir+ str(unit_size)):
            os.makedirs(save_mask_dir + str(unit_size))

        def save_img(x1, x2, y1, y2):
            nonlocal ind  # 'ind' lives in the enclosing gendata_pro scope
            im = img[x1:x2, y1:y2, :]
            if (im[:, :, 0] == 0).sum() > nullthresh:  # skip mostly-empty patches
                return 0
            img_patch = np.array(im[:, :, 0:3])
            mask_patch = np.array(anno_map[x1:x2, y1:y2])
            if balance:
                # keep at most mxnm patches per class; a patch counts for a
                # class when that class covers more than 30% of its pixels
                num_cls = np.array([(mask_patch == p).sum() for p in range(4)])
                if count_cls[0] < mxnm and num_cls[0] > maxthresh:
                    count_cls[0] += 1
                elif count_cls[1] < mxnm and num_cls[1] > maxthresh:
                    count_cls[1] += 1
                elif count_cls[2] < mxnm and num_cls[2] > maxthresh:
                    count_cls[2] += 1
                elif count_cls[3] < mxnm and num_cls[3] > maxthresh:
                    count_cls[3] += 1
                else:
                    if count_cls.sum() == mxnm * 4:
                        return 1  # every class has enough patches: stop
                    else:
                        return 0
            bd = ind + begin_ind
            ind = ind + 1
            # the source array is RGB (loaded via PIL); cv2.imwrite expects BGR
            cv2.imwrite(save_img_dir + str(unit_size) + '/%06d.jpg' % bd,
                        cv2.cvtColor(img_patch, cv2.COLOR_RGB2BGR))
            # the mask belongs under the mask directory
            cv2.imwrite(save_mask_dir + str(unit_size) + '/%06d.png' % bd, mask_patch)
            return 0


        if randis:
            # sample roughly (200000 / unit_size)^2 random windows
            randnum = (200000.0 / unit_size) ** 2
            print(unit_size, randnum)

            for i in range(int(randnum)):
                x1, y1 = random.randint(0, length), random.randint(0, width)
                x2, y2 = x1 + unit_size, y1 + unit_size
                if x2 > length:
                    x2, x1 = length, length - unit_size
                if y2 > width:
                    y2, y1 = width, width - unit_size
                if save_img(x1, x2, y1, y2) == 1:
                    return
        else:
            # sliding window with overlap; keep the stride integral since it
            # is used for array indexing
            stride = int(unit_size * stride_rato)
            x1 = 0
            while x1 < length:
                x2 = x1 + unit_size
                if x2 > length:
                    x2, x1 = length, length - unit_size
                y1 = 0
                while y1 < width:
                    y2 = y1 + unit_size
                    if y2 > width:
                        y2, y1 = width, width - unit_size
                    if save_img(x1, x2, y1, y2) == 1:
                        return
                    y1 += stride
                x1 += stride


# generate train/test list files
def genlabel(unit_sizes=[2048, 1024, 512]):
    save_img_dir = os.path.join(cfg.DATA_DIR, 'image')
    save_mask_dir = os.path.join(cfg.DATA_DIR, 'mask')
    split_rato = 0.8
    for unit_size in unit_sizes:

        train_txt = open(cfg.DATA_DIR + 'train' + str(unit_size) + '.txt', 'w+')
        test_txt = open(cfg.DATA_DIR + 'test' + str(unit_size) + '.txt', 'w+')

        img_list = os.listdir(save_img_dir + str(unit_size))
        mask_list = os.listdir(save_mask_dir + str(unit_size))
        id_list = [img_name.split('.')[0] for img_name in img_list]
        for id in id_list:
            s = random.random()
            if id + '.jpg' not in img_list:
                raise ValueError('label gen error')
            if id + '.png' not in mask_list:
                raise ValueError('label gen error')
            line = 'image' + str(unit_size) + '/' + id + '.jpg mask' + str(unit_size) + '/' + id + '.png\n'
            if s > split_rato:
                test_txt.write(line)
            else:
                train_txt.write(line)
        train_txt.close()
        test_txt.close()


# generate the submission mask for the big test image
def genbig(unit_sizes=[512], stride_rato=0.5):
    from test import *  # pred() is expected to come from test.py
    img = cv2.imread(cfg.DATA_DIR + 'jingwei_round1_test_a_20190619/image_4.png')

    length, width = img.shape[0], img.shape[1]
    mask = np.zeros((length, width))
    print(img.shape)

    for unit_size in unit_sizes:
        stride = int(unit_size * stride_rato)
        x1 = 0
        while x1 < length:
            x2 = x1 + unit_size
            if x2 > length:
                x2, x1 = length, length - unit_size
            y1 = 0
            while y1 < width:
                y2 = y1 + unit_size
                if y2 > width:
                    y2, y1 = width, width - unit_size
                im = img[x1:x2, y1:y2, :]
                if im.max() > 0:  # only run inference on non-empty patches
                    cv2.imwrite("tmp.jpg", im)
                    result = pred("tmp.jpg")
                    mask[x1:x2, y1:y2] = result
                    #mix_img = maskAddImg(im, result)
                    #cv2.imshow('mix_img', mix_img)
                    #cv2.waitKey(500)
                y1 += stride
            x1 += stride
    cv2.imwrite("mask.png", mask)

# overlay the class mask on the image
def maskAddImg(img, mask):
    mask_red = np.zeros_like(mask)
    mask_green = np.zeros_like(mask)
    mask_blue = np.zeros_like(mask)
    mask_red[mask == 1] = 1
    mask_green[mask == 2] = 1
    mask_blue[mask == 3] = 1
    mask_img_n = np.stack((mask_red, mask_green, mask_blue), axis=2)
    # both inputs to addWeighted must have the same dtype
    mix_img = cv2.addWeighted(img, 0.5, (mask_img_n * 255).astype(img.dtype), 0.5, 1)
    return mix_img

genbig()
#genlabel()
#gendata_pro()
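
The stages at the bottom of the script are toggled by commenting calls in and out. The intended order is gendata_pro to cut patches, genlabel to write the list files, then, once a checkpoint has been trained, genbig to stitch the submission mask. A minimal driver, as a sketch (the arguments here are illustrative, not the competition settings):

gendata_pro(unit_sizes=[512], randis=0, balance=1)  # cut patches with a sliding window
genlabel(unit_sizes=[512])                          # write train512.txt / test512.txt
genbig(unit_sizes=[512])                            # run only after training a checkpoint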

Config code

import torch
import argparse
import os
import sys
import cv2
import time
class Configuration():
	def __init__(self):
		self.ROOT_DIR ='./'
		self.EXP_NAME = 'deeplabv3+tianchi'


		self.DATA_DIR = "D:/lengxia/code/deeplabv3plus/data/"
		self.TXT_LIST = ["test512.txt","test1024.txt"]
		self.DATA_NAME = 'tianchi'
		self.DATA_AUG = False
		self.DATA_WORKERS = 4
		self.DATA_RESCALE = 512
		self.DATA_RANDOMCROP = 0
		self.DATA_RANDOMROTATION = 0
		self.DATA_RANDOMSCALE = 1
		self.DATA_RANDOM_H = 10
		self.DATA_RANDOM_S = 10
		self.DATA_RANDOM_V = 10
		self.DATA_RANDOMFLIP = 0.5
		self.DATA_SPLIT=6
		
		self.MODEL_NAME = 'deeplabv3plus'
		self.MODEL_BACKBONE = 'res101_atrous'
		self.MODEL_OUTPUT_STRIDE = 16
		self.MODEL_ASPP_OUTDIM = 256
		self.MODEL_SHORTCUT_DIM = 48
		self.MODEL_SHORTCUT_KERNEL = 1
		self.MODEL_NUM_CLASSES = 4
		self.MODEL_SAVE_DIR = os.path.join(self.ROOT_DIR,'model',self.EXP_NAME)

		self.TRAIN_LR = 0.0007
		self.TRAIN_LR_GAMMA = 0.1
		self.TRAIN_MOMENTUM = 0.9
		self.TRAIN_WEIGHT_DECAY = 0.00004
		self.TRAIN_BN_MOM = 0.0003
		self.TRAIN_POWER = 0.9
		self.TRAIN_GPUS = 1
		self.GPUS_ID=[0]
		self.TRAIN_BATCHES = 8
		self.TRAIN_SHUFFLE = True
		self.TRAIN_MINEPOCH = 0	
		self.TRAIN_EPOCHS = 60
		self.TRAIN_LOSS_LAMBDA = 0
		self.TRAIN_TBLOG = True
		self.TRAIN_CKPT = None  # os.path.join(self.ROOT_DIR,'model/deeplabv3+voc/deeplabv3plus_res101_atrous_VOC2012_epoch46_all.pth')

		self.LOG_DIR = os.path.join(self.ROOT_DIR,'log',self.EXP_NAME)

		self.TEST_MULTISCALE = [1.0]  # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
		self.TEST_FLIP = False  # True
		self.TEST_CKPT = os.path.join(self.ROOT_DIR,'model/deeplabv3+tianchi/deeplabv3plus_res101_atrous_tianchi_itr15000_256_2.pth')
		self.TEST_GPUS = 1
		self.TEST_BATCHES = 1

		self.__check()
		#self.__add_path(os.path.join(self.ROOT_DIR, 'lib'))
		
	def __check(self):
		if not torch.cuda.is_available():
			raise ValueError('configs.py: cuda is not available')
		if self.TRAIN_GPUS == 0:
			raise ValueError('configs.py: the number of GPU is 0')
		if not os.path.isdir(self.LOG_DIR):
			os.makedirs(self.LOG_DIR)
		if not os.path.isdir(self.MODEL_SAVE_DIR):
			os.makedirs(self.MODEL_SAVE_DIR)

	def __add_path(self, path):
		if path not in sys.path:
			sys.path.insert(0, path)



cfg = Configuration() 	
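
Note that cfg is a module-level singleton: importing configs constructs it and runs __check() immediately, so the import itself fails on a machine without CUDA. Every script below consumes it the same way:

from configs import cfg
print(cfg.MODEL_NUM_CLASSES)  # 4: three crop classes plus background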


Training code

# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------

import torch
import torch.nn as nn
import os
import numpy as np
from configs import cfg
from lib.net.generateNet import generate_net
import torch.optim as optim
from torch.utils.data import DataLoader
from lib.net.loss import MaskCrossEntropyLoss, MaskBCELoss, MaskBCEWithLogitsLoss
from lib.net.sync_batchnorm.replicate import patch_replication_callback
import cv2
from TianchiDataset import TianchiDataset

def collate_fn(batch):
	# samples carry extra keys (row/col/name); stack only what training uses
	images = []
	seg = []
	for _, sample in enumerate(batch):
		images.append(sample['image'])
		seg.append(sample['segmentation'])
	return {
		'image': torch.stack(images, 0),
		'segmentation': torch.stack(seg, 0),
			}

def train_net():
	dataset = TianchiDataset(cfg, 'train')
	dataloader = DataLoader(dataset, 
				batch_size=cfg.TRAIN_BATCHES, 
				shuffle=cfg.TRAIN_SHUFFLE, 
				num_workers=cfg.DATA_WORKERS,
				collate_fn=collate_fn,
				drop_last=True)
	net = generate_net(cfg)

	if cfg.TRAIN_TBLOG:
		from tensorboardX import SummaryWriter
		tblogger = SummaryWriter(cfg.LOG_DIR)
	print('Use %d GPU'%cfg.TRAIN_GPUS)
	device = torch.device(cfg.GPUS_ID[0])

	if cfg.TRAIN_GPUS > 1:
		net = nn.DataParallel(net,device_ids=cfg.GPUS_ID)
		patch_replication_callback(net)
	net.to(device)		

	if cfg.TRAIN_CKPT:
		pretrained_dict = torch.load(cfg.TRAIN_CKPT)
		net_dict = net.state_dict()
		pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in net_dict) and (v.shape==net_dict[k].shape)}
		net_dict.update(pretrained_dict)
		net.load_state_dict(net_dict)
		# net.load_state_dict(torch.load(cfg.TRAIN_CKPT),False)
	
	criterion = nn.CrossEntropyLoss(ignore_index=255)
	optimizer = optim.SGD(
		params = [
			{'params': get_params(net.module if cfg.TRAIN_GPUS>1 else net,key='1x'), 'lr': cfg.TRAIN_LR},
			{'params': get_params(net.module if cfg.TRAIN_GPUS>1 else net,key='10x'), 'lr': 10*cfg.TRAIN_LR}
		],
		momentum=cfg.TRAIN_MOMENTUM
	)
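	# param_groups[0] must remain the 1x (backbone) group and param_groups[1]
	# the 10x (head) group: adjust_lr below rescales them by position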

	itr = cfg.TRAIN_MINEPOCH * len(dataloader)
	max_itr = cfg.TRAIN_EPOCHS*len(dataloader)
	running_loss10 = 0.0
	running_loss100 = 0.0
	net.train()
	for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
		for i_batch, sample_batched in enumerate(dataloader):
			now_lr = adjust_lr(optimizer, itr, max_itr)
			inputs_batched, labels_batched = sample_batched['image'], sample_batched['segmentation']
			optimizer.zero_grad()

			labels_batched = labels_batched.long().to(device)
			inputs_batched=inputs_batched.to(device)

			predicts_batched = net(inputs_batched)
			loss = criterion(predicts_batched, labels_batched)
			loss.backward()
			optimizer.step()
			running_loss10 += loss.item()
			running_loss100 += loss.item()
			if i_batch % 10 == 1:
				print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g ' %
					(epoch, cfg.TRAIN_EPOCHS, i_batch, dataset.__len__()//cfg.TRAIN_BATCHES,
					itr+1, now_lr, running_loss10/10))
				running_loss10 = 0.0

			if cfg.TRAIN_TBLOG and itr%50 == 0:
				inputs = inputs_batched[0].cpu().numpy()/2.0 + 0.5
				labels = labels_batched[0].cpu().numpy()
				predicts = torch.argmax(predicts_batched[0], dim=0).cpu().numpy()
				labels_color = dataset.label2colormap(labels).transpose((2,0,1))
				predicts_color = dataset.label2colormap(predicts).transpose((2,0,1))
				pix_acc = np.sum(labels==predicts)/(cfg.DATA_RESCALE**2)

				tblogger.add_scalar('loss', running_loss100/100, itr)
				tblogger.add_scalar('lr', now_lr, itr)
				tblogger.add_scalar('pixel acc', pix_acc, itr)
				tblogger.add_image('Input', inputs, itr)
				tblogger.add_image('Label', labels_color, itr)
				tblogger.add_image('Output', predicts_color, itr)
				#cv2.imshow("label", labels_color.transpose(1, 2, 0))
				#cv2.imshow("pred", predicts_color.transpose(1, 2, 0))
				#cv2.waitKey(2000)
				#cv2.destroyAllWindows()
				running_loss100 = 0.0
			if itr % 5000 == 0:
				save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_itr%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,itr))
				torch.save(net.state_dict(), save_path)
				print('%s has been saved'%save_path)
			itr += 1
		
	save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d_all.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,cfg.TRAIN_EPOCHS))		
	torch.save(net.state_dict(),save_path)
	if cfg.TRAIN_TBLOG:
		tblogger.close()
	print('%s has been saved'%save_path)

def adjust_lr(optimizer, itr, max_itr):
	now_lr = cfg.TRAIN_LR * (1 - itr/(max_itr+1)) ** cfg.TRAIN_POWER
	optimizer.param_groups[0]['lr'] = now_lr
	optimizer.param_groups[1]['lr'] = 10*now_lr
	return now_lr
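
# Poly schedule as in the DeepLab papers: lr = TRAIN_LR * (1 - itr/max_itr)^TRAIN_POWER.
# With TRAIN_LR = 0.0007 and TRAIN_POWER = 0.9 this gives 0.0007 at itr = 0,
# roughly 0.00037 halfway through, and decays to ~0 by the last iteration.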

def get_params(model, key):
	# yield only Conv2d parameters: backbone convs form the '1x' group and all
	# other convs the '10x' group (BN parameters are never handed to the optimizer)
	for m in model.named_modules():
		if key == '1x':
			if 'backbone' in m[0] and isinstance(m[1], nn.Conv2d):
				for p in m[1].parameters():
					yield p
		elif key == '10x':
			if 'backbone' not in m[0] and isinstance(m[1], nn.Conv2d):
				for p in m[1].parameters():
					yield p

if __name__ == '__main__':
	train_net()



Test code

# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from configs import cfg
from transform import ToTensor
from lib.net.generateNet import generate_net
from lib.net.sync_batchnorm.replicate import patch_replication_callback
from torch.utils.data import DataLoader
from TianchiDataset import TianchiDataset


class TianchiModel:
	def __init__(self,cfg):
		self.net = generate_net(cfg)
		self.cfg=cfg
		self.tensor = ToTensor()
		print('net initialize')
	def inittest(self):
		if self.cfg.TEST_CKPT is None:
			raise ValueError('test.py: cfg.TEST_CKPT cannot be empty at test time')
		print('Use %d GPU' % self.cfg.TEST_GPUS)
		device = torch.device(self.cfg.GPUS_ID[0])
		if self.cfg.TEST_GPUS > 1:
			self.net = nn.DataParallel(self.net)
			patch_replication_callback(self.net)
		self.net.to(device)
		print('start loading model %s' % self.cfg.TEST_CKPT)
		model_dict = torch.load(self.cfg.TEST_CKPT, map_location=device)
		self.net.load_state_dict(model_dict)
		self.net.eval()

	def pred(self, img):
		# img is an OpenCV image (BGR); the network was trained on RGB
		image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
		r, c, _ = image.shape
		sample = {'image': image, 'name': "ok", 'row': r, 'col': c}
		sample = self.tensor(sample)
		with torch.no_grad():
			inputs_batched = sample['image'].unsqueeze(0).to(self.cfg.GPUS_ID[0])
			predicts_batched = self.net(inputs_batched)
			result = torch.argmax(predicts_batched, dim=1).cpu().numpy().astype(np.uint8)
		return result[0]


def test_net():
	dataset = TianchiDataset(cfg, 'test')
	dataloader = DataLoader(dataset, 
				batch_size=cfg.TEST_BATCHES, 
				shuffle=False, 
				num_workers=cfg.DATA_WORKERS)
	
	net = generate_net(cfg)
	print('net initialize')
	if cfg.TEST_CKPT is None:
		raise ValueError('test.py: cfg.TEST_CKPT cannot be empty at test time')

	print('Use %d GPU'%cfg.TEST_GPUS)
	device = torch.device('cuda')
	#net = nn.DataParallel(net)
	#patch_replication_callback(net)
	net.to(device)

	print('start loading model %s'%cfg.TEST_CKPT)
	model_dict = torch.load(cfg.TEST_CKPT,map_location=device)
	net.load_state_dict(model_dict)
	
	net.eval()	
	result_list = []
	with torch.no_grad():
		hist=np.zeros((4,4))
		for i_batch, sample_batched in enumerate(dataloader):
			name_batched = sample_batched['name']
			row_batched = sample_batched['row']
			col_batched = sample_batched['col']

			[batch, channel, height, width] = sample_batched['image'].size()
			multi_avg = torch.zeros((batch, cfg.MODEL_NUM_CLASSES, height, width), dtype=torch.float32).to(device)
			for rate in cfg.TEST_MULTISCALE:
				inputs_batched = sample_batched['image_%f'%rate]
				inputs_batched = inputs_batched.to(device)
				predicts = net(inputs_batched).to(device)
				predicts_batched = predicts.clone()
				del predicts
				if cfg.TEST_FLIP:
					inputs_batched_flip = torch.flip(inputs_batched,[3]) 
					predicts_flip = torch.flip(net(inputs_batched_flip),[3]).to(device)
					predicts_batched_flip = predicts_flip.clone()
					del predicts_flip
					predicts_batched = (predicts_batched + predicts_batched_flip) / 2.0
			
				predicts_batched = F.interpolate(predicts_batched, size=None, scale_factor=1/rate, mode='bilinear', align_corners=True)
				multi_avg = multi_avg + predicts_batched
				del predicts_batched
			
			multi_avg = multi_avg / len(cfg.TEST_MULTISCALE)
			result = torch.argmax(multi_avg, dim=1).cpu().numpy().astype(np.uint8)
			predicts_color = dataset.label2colormap(result[0])
			labels_batched = sample_batched['segmentation']
			labels = labels_batched[0].cpu().numpy().astype(np.uint8)
			if labels.max() > 0:
				hist += fast_hist(labels, result[0], 4)
				print(per_class_iu(hist))

			labels_color = dataset.label2colormap(labels)
			#print(predicts_color.shape)
			orimg= cv2.resize(sample_batched['orimg'][0].numpy(), dsize=(512,512), interpolation=cv2.INTER_CUBIC).astype(np.uint8)
			#print(orimg.shape,predicts_color.shape,type(predicts_color),type(orimg))
			mix_img = maskAddImg(orimg, result[0])
			cv2.imshow('pred', mix_img)
			mix_img = maskAddImg(orimg, labels)
			cv2.imshow('label', mix_img)
			#cv2.imshow("result", predicts_color)
			#cv2.imshow("label", labels_color)
			cv2.waitKey(500)
		print("all",per_class_iu(hist))



def fast_hist(a, b, n):
    # confusion matrix between ground truth a and prediction b over n classes
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iu(hist):
    # per-class IoU = TP / (TP + FP + FN); NaN (class absent on both sides) becomes 1
    iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    iou[iou != iou] = 1
    return iou
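
# Toy check: for two classes, hist = [[3, 1], [2, 4]] gives
# IoU = diag / (row_sum + col_sum - diag) = [3/(4+5-3), 4/(6+5-4)] = [0.5, 0.571].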


def maskAddImg(img, mask):
    # code classes 1/2/3 into the three image channels and blend with the image
    mask_red = np.zeros_like(mask)
    mask_green = np.zeros_like(mask)
    mask_blue = np.zeros_like(mask)
    mask_red[mask == 1] = 1
    mask_green[mask == 2] = 1
    mask_blue[mask == 3] = 1
    mask_img_n = np.stack((mask_red, mask_green, mask_blue), axis=2)
    mix_img = cv2.addWeighted(img, 0.5, (mask_img_n * 255).astype(img.dtype), 0.5, 1)
    return mix_img




if __name__ == '__main__':
	test_net()
	#tianchimodel=TianchiModel(cfg)
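
One loading pitfall: checkpoints saved while the net was wrapped in nn.DataParallel carry a module. prefix on every key, while test_net() loads into an unwrapped network. A small helper, as a sketch (strip_module is not part of the original code), can normalize the keys before load_state_dict:

def strip_module(state_dict):
    # drop the 'module.' prefix that nn.DataParallel adds to parameter names
    return {k[7:] if k.startswith('module.') else k: v for k, v in state_dict.items()}

# e.g.: net.load_state_dict(strip_module(torch.load(cfg.TEST_CKPT, map_location=device)))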



Dataset code

# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------

from __future__ import print_function, division
import sys
import os
import torch
import cv2
import multiprocessing
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from transform import *
import random

class TianchiDataset(Dataset):
    def __init__(self, cfg, period):
        self.dataset_name = cfg.DATA_NAME
        self.dataset_dir =cfg.DATA_DIR
        self.name_list=[]
        self.period=period
        for list_txt in cfg.TXT_LIST:
            with open(self.dataset_dir+list_txt,'r+') as fp:
                self.name_list.extend(fp.readlines())
        self.rescale = None
        self.centerlize = None
        self.randomcrop = None
        self.randomflip = None
        self.randomrotation = None
        self.randomscale = None
        self.randomhsv = None
        self.multiscale = None
        self.totensor = ToTensor()
        self.cfg = cfg
        self.nums = int(len(self.name_list) / self.cfg.DATA_SPLIT)
        if 'tianchi' in self.dataset_name:
            self.categories = [
                'kaoyan',   # 1
                'yumi',     # 2
                'yirenmi',  # 3
            ]
            self.num_categories = len(self.categories)
            assert (self.num_categories + 1 == self.cfg.MODEL_NUM_CLASSES)


        if cfg.DATA_RESCALE > 0:
            self.rescale = Rescale(cfg.DATA_RESCALE, fix=False)
        if 'train' in self.period:
            if cfg.DATA_RANDOMCROP > 0:
                self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP)
            if cfg.DATA_RANDOMROTATION > 0:
                self.randomrotation = RandomRotation(cfg.DATA_RANDOMROTATION)
            if cfg.DATA_RANDOMSCALE != 1:
                self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE)
            if cfg.DATA_RANDOMFLIP > 0:
                self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP)
            if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0:
                self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V)
        else:
            self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE)

    def __len__(self):
        return self.nums

    def __getitem__(self, idx):
        # the list is viewed as DATA_SPLIT shards of self.nums entries each;
        # an epoch index is served from a randomly chosen shard
        idx = random.randint(0, self.cfg.DATA_SPLIT - 1) * self.nums + idx
        name = self.name_list[idx]
        img_file = self.dataset_dir + name.split()[0]
        image = cv2.imread(img_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        r, c, _ = image.shape
        sample = {'image': image, 'name': name, 'row': r, 'col': c,'orimg':image}
        seg_file = self.dataset_dir + name.split()[1]
        segmentation = np.array(Image.open(seg_file))
        sample['segmentation'] = segmentation
        if 'train' in self.period:
            if self.cfg.DATA_RANDOM_H > 0 or self.cfg.DATA_RANDOM_S > 0 or self.cfg.DATA_RANDOM_V > 0:
                sample = self.randomhsv(sample)
            if self.cfg.DATA_RANDOMFLIP > 0:
                sample = self.randomflip(sample)
            if self.cfg.DATA_RANDOMROTATION > 0:
                sample = self.randomrotation(sample)
            if self.cfg.DATA_RANDOMSCALE != 1:
                sample = self.randomscale(sample)
            if self.cfg.DATA_RANDOMCROP > 0:
                sample = self.randomcrop(sample)
            if self.cfg.DATA_RESCALE > 0:
                # sample = self.centerlize(sample)
                sample = self.rescale(sample)
        else:
            if self.cfg.DATA_RESCALE > 0:
                sample = self.rescale(sample)
            sample = self.multiscale(sample)
        sample = self.totensor(sample)
        #print(sample['image'].shape, sample['segmentation'].shape, sample['r'], sample['c'])
        return sample
    def label2colormap(self, label):
        m = label.astype(np.uint8)
        r,c = m.shape
        cmap = np.zeros((r,c,3), dtype=np.uint8)
        cmap[:,:,0] = (m&1)<<7 | (m&8)<<3
        cmap[:,:,1] = (m&2)<<6 | (m&16)<<2
        cmap[:,:,2] = (m&4)<<5
        return cmap
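
    # With MODEL_NUM_CLASSES = 4, this bit-interleaving reproduces the standard
    # PASCAL VOC palette: 0 -> (0,0,0), 1 -> (128,0,0), 2 -> (0,128,0), 3 -> (128,128,0).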


def check_data():

    from data_process import maskAddImg
    from configs import cfg
    datasets = TianchiDataset(cfg=cfg, period='train')
    for indx in range(len(datasets)):
        sample = datasets.__getitem__(indx)
        img = sample['image']  # RGB; OpenCV windows expect BGR
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        mask = sample['segmentation']
        try:
            mix_img = maskAddImg(img, mask)
        except Exception:
            print(sample['name'])
            continue

        mix_img_s = np.concatenate((sample['orimg'], img, mix_img), 1)
        cv2.imshow('mix_img', mix_img_s)
        k = cv2.waitKey(0)
        if k == ord('q'):
            cv2.destroyAllWindows()
            break
    cv2.destroyAllWindows()

#check_data()