PyTorch 深度學習:37分鐘快速入門——FCN 做語義分割

語義分割是一種像素級別的處理圖像方式,對比於目標檢測其更加精確,能夠自動從圖像中劃分出對象區域並識別對象區域中的類別

在 2015 年 CVPR 的一篇論文 Fully Convolutional Networks for Semantic Segmentation 這篇文章提出了全卷積的概念,第一次將端到端的卷積網絡推廣到了語義分割的任務當中,隨後出現了很多基於 FCN 實現的網絡結構,比如 U-Net 等。

數據集 首先我們需要下載數據集,這裏我們使用 PASCAL VOC 數據集,其是一個正在進行的目標檢測,目標識別,語義分割的挑戰,我們可以進行數據集的下載 下載完成數據集之後進行解壓,我們可以再 ImageSets/Segmentation/train.txt 和 ImageSets/Segmentation/val.txt 中找到我們的訓練集和驗證集的數據,圖片存放在 /JPEGImages 中,後綴是 .jpg,而 label 存放在 /SegmentationClass 中,後綴是 .png 我們可以可視化一下

 

# 導入需要的包
import os
import torch
import numpy as np
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from mxtorch import transforms as tfs
from datetime import datetime

import matplotlib.pyplot as plt
%matplotlib inline

im_show1 = Image.open('./dataset/VOCdevkit/VOC2012/JPEGImages/2007_005210.jpg') label_show1 = Image.open('./dataset/VOCdevkit/VOC2012/SegmentationClass/2007_005210.png').convert('RGB') im_show2 = Image.open('./dataset/VOCdevkit/VOC2012/JPEGImages/2007_000645.jpg') label_show2 = Image.open('./dataset/VOCdevkit/VOC2012/SegmentationClass/2007_000645.png').convert('RGB')

_, figs = plt.subplots(2, 2, figsize=(10, 8))
figs[0][0].imshow(im_show1)
figs[0][0].axes.get_xaxis().set_visible(False)
figs[0][0].axes.get_yaxis().set_visible(False)
figs[0][1].imshow(label_show1)
figs[0][1].axes.get_xaxis().set_visible(False)
figs[0][1].axes.get_yaxis().set_visible(False)
figs[1][0].imshow(im_show2)
figs[1][0].axes.get_xaxis().set_visible(False)
figs[1][0].axes.get_yaxis().set_visible(False)
figs[1][1].imshow(label_show2)
figs[1][1].axes.get_xaxis().set_visible(False)
figs[1][1].axes.get_yaxis().set_visible(False)
print(im_show1.size)
print(im_show2.size)

首先輸出圖片的大小,左邊就是真實的圖片,右邊就是分割之後的結果

然後我們定義一個函數進行圖片的讀入,根據 `train.txt` 和 `val.txt` 中的文件名進行圖片讀入,我們不需要這一步就讀入圖片,只需要知道圖片的路徑,之後根據圖片名稱生成 batch 的時候再讀入圖片,並做一些數據預處理

voc_root = './dataset/VOCdevkit/VOC2012'

def read_images(root=voc_root, train=True):
    txt_fname = root + '/ImageSets/Segmentation/' + ('train.txt' if train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    data = [os.path.join(root, 'JPEGImages', i+'.jpg') for i in images]
    label = [os.path.join(root, 'SegmentationClass', i+'.png') for i in images]
    return data, label

可能你已經注意到了前面展示的兩張圖片的大小是不一樣的,如果我們要使用一個 batch 進行計算,我們需要圖片的大小保持一致,在前面使用卷積網絡進行圖片分類的任務中,我們通過 resize 的辦法對圖片進行了縮放,使得他們的大小相同,但是這裏會遇到一個問題,對於輸入圖片我們當然可以 resize 成任意我們想要的大小,但是 label 也是一張圖片,且是在 pixel 級別上的標註,所以我們沒有辦法對 label 進行有效的 resize 似的其也能達到像素級別的匹配,所以爲了使得輸入的圖片大小相同,我們就使用 crop 的方式來解決這個問題,也就是從一張圖片中 crop 出固定大小的區域,然後在 label 上也做同樣方式的 crop。
使用 crop 可以使用 pytorch 中自帶的 transforms,不過要稍微改一下,不僅輸出 crop 出來的區域,同時還要輸出對應的座標便於我們在 label 上做相同的 crop

def random_crop(data, label, crop_size):
    height, width = crop_size
    data, rect = tfs.RandomCrop((height, width))(data)
    label = tfs.FixedCrop(*rect)(label)
    return data, label

 下面我們可以驗證一下隨機 crop

_, figs = plt.subplots(2, 2, figsize=(10, 8))
crop_im1, crop_label1 = random_crop(im_show1, label_show1, (200, 300))
figs[0][0].imshow(crop_im1)
figs[0][1].imshow(crop_label1)
figs[0][0].axes.get_xaxis().set_visible(False)
figs[0][1].axes.get_yaxis().set_visible(False)
crop_im2, crop_label2 = random_crop(im_show1, label_show1, (200, 300))
figs[1][0].imshow(crop_im2)
figs[1][1].imshow(crop_label2)
figs[1][0].axes.get_xaxis().set_visible(False)
figs[1][1].axes.get_yaxis().set_visible(False)

上面就是我們做兩次隨機 crop 的結果,可以看到圖像和 label 能夠完美的對應起來 接着我們根據數據知道里面有 21 中類別,同時給出每種類別對應的 RGB 值

classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']

# RGB color for each class
colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

len(classes), len(colormap)

接着可以建立一個索引,也就是將一個類別的 RGB 值對應到一個整數上,通過這種一一對應的關係,能夠將 label 圖片變成一個矩陣,矩陣和原圖片一樣大,但是隻有一個通道數,也就是 (h, w) 這種大小,裏面的每個數值代表着像素的類別

cm2lbl = np.zeros(256**3) # 每個像素點有 0 ~ 255 的選擇,RGB 三個通道
for i,cm in enumerate(colormap):
    cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i # 建立索引

def image2label(img):
    data = np.array(img, dtype='int32')
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return np.array(cm2lbl[idx], dtype='int64') # 根據索引得到 label 矩陣

定義完成之後,我們可以驗證一下

label_im = Image.open('./dataset/VOCdevkit/VOC2012/SegmentationClass/2007_000033.png').convert('RGB') label_im

接着我們可以定義數據預處理方式,之前我們讀取的數據只有文件名,現在我們開始做預處理,非常簡單,首先隨機 crop 出固定大小的區域,然後使用 ImageNet 的均值和方差做標準化。

def img_transforms(img, label, crop_size):
    img, label = random_crop(img, label, crop_size)
    img_tfs = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    img = img_tfs(img)
    label = image2label(label)
    label = torch.from_numpy(label)
    return img, label

定義一個 COVSegDataset 繼承於 torch.utils.data.Dataset 構成我們自定的訓練集

class VOCSegDataset(Dataset):
    '''
    voc dataset
    '''
    def __init__(self, train, crop_size, transforms):
        self.crop_size = crop_size
        self.transforms = transforms
        data_list, label_list = read_images(train=train)
        self.data_list = self._filter(data_list)
        self.label_list = self._filter(label_list)
        print('Read ' + str(len(self.data_list)) + ' images')
        
    def _filter(self, images): # 過濾掉圖片大小小於 crop 大小的圖片
        return [im for im in images if (Image.open(im).size[1] >= self.crop_size[0] and 
                                        Image.open(im).size[0] >= self.crop_size[1])]
        
    def __getitem__(self, idx):
        img = self.data_list[idx]
        label = self.label_list[idx]
        img = Image.open(img)
        label = Image.open(label).convert('RGB')
        img, label = self.transforms(img, label, self.crop_size)
        return img, label
    
    def __len__(self):
        return len(self.data_list)

# 實例化數據集
input_shape = (320, 480)
voc_train = VOCSegDataset(True, input_shape, img_transforms)
voc_test = VOCSegDataset(False, input_shape, img_transforms)

train_data = DataLoader(voc_train, 64, shuffle=True, num_workers=4)
valid_data = DataLoader(voc_test, 128, num_workers=4)

在 pytorch 中轉置卷積可以使用 torch.nn.ConvTranspose2d() 來實現,下面我們舉個例子

x = torch.randn(1, 3, 120, 120)
conv_trans = nn.ConvTranspose2d(3, 10, 4, stride=2, padding=1)
y = conv_trans(Variable(x))
print(y.shape)

torch.Size([1, 10, 240, 240])  可以看到輸出變成了輸入的 2 倍

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章