PyTorch 搭建FCN

實驗結論請見 GitHub。

voc.py

'''
@Author  :   {AishuaiYao}
@License :   (C) Copyright 2013-2020, {None}
@Contact :   {[email protected]}
@Software:   ${segmentation}
@File    :   ${voc}.py
@Time    :   ${2020-04-04}
@Desc    :   deconvlution experiment
'''
 
import os
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim.lr_scheduler as lr_scheduler
from torchsummary import summary
from torchvision import transforms,datasets
from torch.utils.data import Dataset,DataLoader
from FCN.fcn import *
import numpy as np
import cv2
from PIL import Image
 
 
# Pascal VOC 2012 segmentation class names; a class's position in this list
# is the integer class id written into the label tensors.
classes = [
    'background',
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor',
]

# RGB colour used for each class in the mask images; colormap[i] is the
# colour of classes[i].
colormap = [
    [0, 0, 0],
    [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128],
    [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0],
    [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128],
]

# Dataset location and training settings.
voc_path = '../data/VOC2012'
BATCH_SIZE = 1
num_classes = len(classes)  # 21
epochs = 200
input_size = 512  # side of the square canvas img2label pads images onto
 
 
def read_images(path = voc_path, train = True):
    """Return the image and mask file paths of one VOC segmentation split.

    Args:
        path: root directory of the VOC2012 dataset.
        train: read ``ImageSets/Segmentation/train.txt`` when true,
            otherwise ``val.txt``.

    Returns:
        (datas, labels): parallel lists of ``JPEGImages/<id>.jpg`` and
        ``SegmentationClass/<id>.png`` paths, in split-file order.
    """
    split_name = 'train.txt' if train else 'val.txt'
    with open(path + '/ImageSets/Segmentation/' + split_name) as split_file:
        sample_ids = split_file.read().split()

    datas = []
    labels = []
    for sample_id in sample_ids:
        datas.append(path + '/JPEGImages/' + sample_id + '.jpg')
        labels.append(path + '/SegmentationClass/' + sample_id + '.png')
    return datas, labels
 
 
def preproccessing(datas,labels):
    """Eagerly convert every (image path, mask path) pair into tensors.

    Args:
        datas: image file paths.
        labels: mask file paths, parallel to ``datas``.

    Returns:
        list of ``(img_tensor, label_tensor)`` tuples, one per pair, as
        produced by ``img2label``. The whole split is held in memory, so this
        is only suitable for small subsets.
    """
    # zip yields 2-tuples; they must be unpacked as a pair, not as part of a
    # flat 3-name target alongside an enumerate index.
    return [img2label(img, label) for img, label in zip(datas, labels)]
 
 
 
def img2label(img,label,canvas_size = input_size):
    """Load an image/mask pair, centre both on a square canvas, and tensorize.

    Args:
        img: path of the RGB image.
        label: path of the colour-coded segmentation mask, whose colours are
            those listed in ``colormap``.
        canvas_size: side length of the square zero-filled canvas that the
            image and mask are centred on. Padding becomes class 0
            (background).

    Returns:
        (img2tensor, label2tensor):
        - ``img2tensor``: a float32 ``(3, canvas_size, canvas_size)`` tensor
          scaled to [0, 1] (no mean/std normalisation).
        - ``label2tensor``: an int64 ``(canvas_size, canvas_size)`` tensor of
          class indices. Pixels whose colour is not in ``colormap`` map to 0.

    Raises:
        FileNotFoundError: if OpenCV cannot read either file.
        ValueError: if the image is larger than the canvas.
    """
    image = cv2.imread(img)
    mask = cv2.imread(label)
    if image is None or mask is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image %r or label %r' % (img, label))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

    height, width = image.shape[:2]
    if height > canvas_size or width > canvas_size:
        raise ValueError('image of size %dx%d does not fit a %d canvas'
                         % (height, width, canvas_size))
    top = (canvas_size - height) // 2
    left = (canvas_size - width) // 2
    rows = slice(top, top + height)
    cols = slice(left, left + width)

    # The canvas must be uint8. ToTensor converts to float and divides by 255
    # only for uint8 input. An int64 array (what np.full(shape, 0) creates)
    # would come back unchanged as a LongTensor, which conv layers reject.
    img_canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8)
    img_canvas[rows, cols, :] = image
    img2tensor = transforms.ToTensor()(img_canvas)

    # Translate colours to class ids in a separate map. Unknown colours then
    # stay 0 rather than leaking raw channel values as class ids.
    class_map = np.zeros((height, width), dtype=np.int64)
    for index, rgb in enumerate(colormap):
        class_map[np.all(mask == rgb, axis=-1)] = index
    label_canvas = np.zeros((canvas_size, canvas_size), dtype=np.int64)
    label_canvas[rows, cols] = class_map
    label2tensor = torch.from_numpy(label_canvas)

    return img2tensor, label2tensor
 
 
class VOCSegGenerator(Dataset):
    """Map-style dataset over one Pascal VOC segmentation split.

    Item ``idx`` is the ``(image_tensor, label_tensor)`` pair that
    ``img2label`` builds from the idx-th image/mask path of the split.
    """

    def __init__(self, train):
        super().__init__()
        datas, labels = read_images(path=voc_path, train=train)
        self.data_list = datas
        self.label_list = labels
        self.len = len(datas)
        print(f'Read {self.len} images')

    def __getitem__(self, idx):
        return img2label(self.data_list[idx], self.label_list[idx])

    def __len__(self):
        return self.len
 
 
# Build the train/val datasets and loaders.
# NOTE(review): the name `train` is rebound by `def train(...)` further down;
# train_loader keeps a reference to the dataset, so this currently works.
train = VOCSegGenerator(train = True)
valid = VOCSegGenerator(train = False)
 
train_loader = DataLoader(dataset = train,batch_size=BATCH_SIZE,shuffle=True)
valid_loader = DataLoader(dataset = valid,batch_size=BATCH_SIZE)
 
# NOTE(review): `device` is only used by predict(); train() moves batches to
# device_ids[0] instead, so the two may point at different GPUs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FCNx8_ResNet(num_classes)
 
# model.to(device)
 
# NOTE(review): GPUs 2 and 3 are hard-coded; this fails on machines with
# fewer GPUs -- adjust to the devices actually available.
device_ids = [2,3]
model = torch.nn.DataParallel(model, device_ids=device_ids) # declare all devices to use
model = model.cuda(device=device_ids[0])  # put the model on the primary device
 
# Prints a summary of the freshly built model; that model is then replaced
# by the one loaded below.
summary(model,(3,input_size,input_size))
# Loads a whole pickled model object. torch.load unpickles arbitrary
# objects, so only use trusted checkpoint files.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# may refuse a whole pickled model -- confirm against the installed version.
model = torch.load('./model/fcnx8resnet50_1.pkl')
# optimizer = torch.optim.Adam(model.parameters(),lr=1e-2,weight_decay=1e-4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=5*1e-4,momentum=0.9)
 
 
def train(model,device,train_loader,optimizer,epoch):
    """Run one training epoch over ``train_loader``.

    Each batch is moved to ``device_ids[0]`` (the primary DataParallel GPU).
    The ``device`` argument is accepted but not used. The loss is
    ``log_softmax`` over the class dimension followed by ``NLLLoss``. The
    loss of every 30th batch is printed.
    """
    model.train()
    criterion = nn.NLLLoss()
    for step, (images, targets) in enumerate(train_loader):
        primary = device_ids[0]
        images = images.cuda(device=primary)
        targets = targets.cuda(device=primary)

        optimizer.zero_grad()
        log_probs = F.log_softmax(model(images), dim=1)
        loss = criterion(log_probs, targets)
        loss.backward()
        optimizer.step()

        if step % 30 == 0:
            seen = step * len(images)
            print('train {} epoch : {}/{} \t loss : {:.6f}'.format(
                epoch, seen, len(train_loader.dataset), loss.item()))
 
 
def predict(mdoel,device,valid_loader):
    """Write a few colour-coded predictions to ``./result/resnet/``.

    For each of the first 7 batches, the first sample is saved as three
    files: ``<i>_img.jpg`` (input), ``<i>_label.jpg`` (ground truth) and
    ``<i>_pred.png`` (argmax prediction).

    Args:
        mdoel: the network to evaluate. The spelling is kept so that keyword
            callers keep working.
        device: device the input batches are moved to.
        valid_loader: loader yielding ``(image, label)`` batches.
    """
    # cv2.imwrite returns False instead of raising when the folder is
    # missing, so make sure it exists.
    os.makedirs('./result/resnet', exist_ok=True)
    palette = np.array(colormap).astype('uint8')
    mdoel.eval()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(valid_loader):
            data = data.to(device)
            output = F.log_softmax(mdoel(data), dim=1)

            # Only the first sample of each batch is visualised; [0] instead
            # of squeeze() also works for batch sizes above 1.
            pred = output.max(dim=1)[1][0].cpu().numpy()
            # Undo ToTensor's /255 scaling and convert RGB -> BGR, because
            # OpenCV writes BGR.
            img = data[0].cpu().numpy().transpose((1, 2, 0)) * 255
            img = np.clip(np.rint(img), 0, 255).astype(np.uint8)[:, :, ::-1]
            truth = palette[label[0].cpu().numpy()][:, :, ::-1]
            pred = palette[pred][:, :, ::-1]

            cv2.imwrite('./result/resnet/%d_img.jpg' % batch_idx, img)
            cv2.imwrite('./result/resnet/%d_label.jpg'%batch_idx,truth)
            cv2.imwrite('./result/resnet/%d_pred.png'%batch_idx,pred)
            if batch_idx >= 6:
                break
 
 
def adjust_learning_rate(optimizer, epoch, base_lr=0.01, step_size=50, gamma=0.1):
    """Set a step-decayed learning rate on every parameter group.

    The new rate is ``base_lr * gamma ** (epoch // step_size)``. With the
    defaults it starts at 0.01 and is divided by 10 every 50 epochs. The old
    and new rate of each group are printed.

    Args:
        optimizer: a ``torch.optim`` optimizer whose groups are updated in
            place.
        epoch: current 0-based epoch index.
        base_lr: learning rate used for the first ``step_size`` epochs.
        step_size: number of epochs between decays; must be positive.
        gamma: multiplicative decay applied at each step.
    """
    lr = base_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        print('lr : ',param_group['lr'])
        param_group['lr'] = lr
        print('reduce to ',param_group['lr'])
 
# Training loop (currently disabled). Each epoch it step-decays the LR,
# trains once over train_loader and pickles the whole model to disk.
# for epoch in range(epochs):
#     print(epoch)
#     adjust_learning_rate(optimizer,epoch)
#     train(model,device,train_loader,optimizer,epoch)
#     torch.save(model,'./model/fcnx8resnet50_1.pkl')
 
# Save qualitative predictions.
# NOTE(review): this visualises the *training* split (train_loader); pass
# valid_loader instead to look at held-out images.
predict(model,device,train_loader)
 

fcn.py

'''
@Author  :   {AishuaiYao}
@License :   (C) Copyright 2013-2020, {None}
@Contact :   {[email protected]}
@Software:   ${segmentation}
@File    :   ${fcn}.py
@Time    :   ${2020-04-06}
@Desc    :   deconvlution experiment
'''
 
import torch
import torch.nn as nn
import numpy as np
import torchvision.models as models
#
# Module-level ImageNet-pretrained VGG16, built as soon as this module is
# imported. torchvision may download the weights the first time.
pretrained_net = models.vgg16(pretrained=True)
# a = pretrained_net.children()
# n = list(pretrained_net.children())[0]
# # d = len(n._modules)
# # b = n._modules['0']
# # c = nn.Conv2d(3,64,3,1,0,bias=False)
# e = []
# for i in range(31):
#     e.append(n._modules[str(i)])
#
# f = nn.Sequential(*e)
#
# #
# #
# # l = nn.Sequential(*list(pretrained_net.children())[:-1])
# print('x')
 
 
class FCNx8_VGG(nn.Module):
    """FCN-8s segmentation network on an ImageNet-pretrained VGG16 backbone.

    The VGG16 ``features`` stack is split at its pooling layers:

    * ``stage1`` = layers 0-16 (through pool3): 256 channels, stride 8
    * ``stage2`` = layers 17-23 (through pool4): 512 channels, stride 16
    * ``stage3`` = layers 24-30 (through pool5), then 1x1 convs 512->4096->4096

    Each stage is scored by a 1x1 conv. The scores are fused from coarse to
    fine using transposed convolutions initialised as bilinear upsamplers
    (x2, x2, then x8). For inputs whose side is a multiple of 32, the result
    is ``(N, num_classes, H, W)`` logits.
    """

    def __init__(self,num_classes):
        super().__init__()
        # Every instance builds its own pretrained backbone. Taking layers
        # from the shared module-level `pretrained_net` would make all
        # instances share, and jointly train, the very same conv modules.
        features = models.vgg16(pretrained=True).features
        self.stage1 = nn.Sequential(*features[:17])    # conv1_x..conv3_x + pool3
        self.stage2 = nn.Sequential(*features[17:24])  # conv4_x + pool4
        self.stage3 = nn.Sequential(
            *features[24:31],                          # conv5_x + pool5
            # NOTE(review): there is no non-linearity between these two 1x1
            # convs, so together they are one linear map. Adding a ReLU would
            # shift the Sequential indices stored in existing checkpoints.
            nn.Conv2d(in_channels=512,out_channels=4096,kernel_size=1,stride=1,padding=0),
            nn.Conv2d(in_channels=4096,out_channels=4096,kernel_size=1,stride=1,padding=0),
        )

        self.scores3 = nn.Conv2d(in_channels=4096,out_channels=num_classes,kernel_size=1)
        self.scores2 = nn.Conv2d(in_channels=512,out_channels=num_classes,kernel_size=1)
        self.scores1 = nn.Conv2d(in_channels=256,out_channels=num_classes,kernel_size=1)

        # Transposed-conv output size: N = (w - 1) * s + k - 2p.
        # k=16, s=8, p=4 upsamples exactly x8; k=4, s=2, p=1 exactly x2.
        self.upsample_8x = nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=16,stride=8,padding=4,bias= False)
        self.upsample_8x.weight.data = self.bilinear_kernel(num_classes,num_classes,16)
        self.upsample_16x = nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=4,stride=2,padding=1,bias=False)
        self.upsample_16x.weight.data = self.bilinear_kernel(num_classes,num_classes,4)
        self.upsample_32x = nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=4,stride=2,padding=1,bias=False)
        self.upsample_32x.weight.data = self.bilinear_kernel(num_classes,num_classes,4)

    def forward(self, x):
        x = self.stage1(x)
        s1 = x

        x = self.stage2(x)
        s2 = x

        x = self.stage3(x)
        s3 = x

        # stride-32 scores, upsampled x2 to match stage2
        s3 = self.scores3(s3)
        s3 = self.upsample_32x(s3)

        # fuse with stride-16 scores, upsample x2 to match stage1
        s2 = self.scores2(s2)
        s2 = s2 + s3
        s2 = self.upsample_16x(s2)

        # fuse with stride-8 scores, upsample x8 back to input resolution
        s1 = self.scores1(s1)
        s = s1 + s2
        s = self.upsample_8x(s)

        return s

    def bilinear_kernel(self,in_channels, out_channels, kernel_size):
        '''
        Return a float32 tensor of shape (in_channels, out_channels, k, k),
        the weight layout of ConvTranspose2d. Each diagonal slice [c, c]
        holds a 2-D bilinear interpolation filter and every other slice is
        zero. Only the first min(in_channels, out_channels) diagonal slices
        are filled, so unequal channel counts are accepted too.
        '''
        factor = (kernel_size + 1) // 2
        center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
        og = np.ogrid[:kernel_size, :kernel_size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
        diag = range(min(in_channels, out_channels))
        weight[diag, diag, :, :] = filt
        return torch.from_numpy(weight)
 
 
class FCNx8_ResNet(nn.Module):
    """FCN-8s style segmentation network on an ImageNet-pretrained ResNet-50.

    * ``head``   = conv1, bn1, relu, maxpool (stride 4)
    * ``stage1`` = layer1 + layer2: 512 channels, stride 8
    * ``stage2`` = layer3: 1024 channels, stride 16
    * ``stage3`` = layer4 (2048 channels, stride 32), a 7x7 stride-1 average
      pool, then 1x1 convs 2048->1024->512

    The upsampling chain is sized for 512x512 inputs:
    1. stage3 is 16x16 and the pool brings it to 10x10.
    2. The k=5/s=3 transposed conv gives 32x32, matching stage2.
    3. The x2 and x8 bilinear upsamplers then restore 512x512.

    For most other input sizes the skip additions fail on a shape mismatch.
    """

    def __init__(self,num_classes):
        super(FCNx8_ResNet, self).__init__()
        # children: conv1, bn1, relu, maxpool, layer1..layer4, avgpool (fc dropped)
        backbone = list(models.resnet50(pretrained=True).children())[:-1]

        self.head = nn.Sequential(*backbone[:4])
        self.stage1 = nn.Sequential(*backbone[4], *backbone[5])
        self.stage2 = nn.Sequential(*backbone[6])
        self.stage3 = nn.Sequential(
            *backbone[7],
            # Current torchvision ResNets end in AdaptiveAvgPool2d((1, 1)); older
            # releases used AvgPool2d(7, stride=1). A 1x1 map would be upsampled
            # to 5x5 and could not be added to the 32x32 stage2 scores, so the
            # fixed 7x7/stride-1 pool this network is sized for is used
            # explicitly. It has no parameters, so checkpoint keys are unchanged.
            nn.AvgPool2d(kernel_size=7, stride=1),
            # NOTE(review): no non-linearity between these 1x1 convs.
            nn.Conv2d(in_channels=2048,out_channels=1024,kernel_size=1,stride=1,padding=0),
            nn.Conv2d(in_channels=1024,out_channels=512,kernel_size=1,stride=1,padding=0),
        )

        self.scores3 = nn.Conv2d(in_channels=512,out_channels=num_classes,kernel_size=1)
        self.scores2 = nn.Conv2d(in_channels=1024,out_channels=num_classes,kernel_size=1)
        self.scores1 = nn.Conv2d(in_channels=512,out_channels=num_classes,kernel_size=1)

        # Transposed-conv output size: N = (w - 1) * s + k - 2p.
        # k=16, s=8, p=4 -> x8; k=4, s=2, p=1 -> x2; k=5, s=3, p=0 -> 10 -> 32.
        self.upsamplex8 = nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=16,stride=8,padding=4,bias= False)
        self.upsamplex8.weight.data = self.bilinear_kernel(num_classes,num_classes,16)
        self.upsamplex16 = nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=4,stride=2,padding=1,bias=False)
        self.upsamplex16.weight.data = self.bilinear_kernel(num_classes,num_classes,4)
        self.upsamplex32= nn.ConvTranspose2d(in_channels=num_classes,out_channels=num_classes,kernel_size=5,stride=3,padding=0,bias=False)
        self.upsamplex32.weight.data = self.bilinear_kernel(num_classes,num_classes,5)

    def forward(self, x):
        x = self.head(x)
        x = self.stage1(x)
        s1 = x

        x = self.stage2(x)
        s2 = x

        x = self.stage3(x)
        s3 = x

        # pooled stride-32 scores, upsampled to stage2 resolution
        s3 = self.scores3(s3)
        s3 = self.upsamplex32(s3)

        # fuse with stride-16 scores, upsample x2 to stage1 resolution
        s2 = self.scores2(s2)
        s2 = s2 + s3
        s2 = self.upsamplex16(s2)

        # fuse with stride-8 scores, upsample x8 back to input resolution
        s1 = self.scores1(s1)
        s = s1 + s2
        s = self.upsamplex8(s)

        return s

    def bilinear_kernel(self,in_channels, out_channels, kernel_size):
        '''
        Return a float32 tensor of shape (in_channels, out_channels, k, k),
        the weight layout of ConvTranspose2d. Each diagonal slice [c, c]
        holds a 2-D bilinear interpolation filter and every other slice is
        zero. Only the first min(in_channels, out_channels) diagonal slices
        are filled, so unequal channel counts are accepted too.
        '''
        factor = (kernel_size + 1) // 2
        center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
        og = np.ogrid[:kernel_size, :kernel_size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
        diag = range(min(in_channels, out_channels))
        weight[diag, diag, :, :] = filt
        return torch.from_numpy(weight)

從左到右依次為：原圖、標籤、fcnx8_resnet50 預測結果、fcnx8_vgg 預測結果。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章