【庖丁解牛】從零實現RetinaNet(一):COCO與VOC數據集處理

所有代碼已上傳到本人github repository:https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
如果覺得有用,請點個star喲!
下列代碼均在pytorch1.4版本中測試過,確認正確無誤。

前言

經過前面的base model系列ImageNet訓練實踐,筆者終於要開始學習目標檢測了。目標檢測這塊的細節特別多,而這些細節在論文中通常不會提及(往往是繼承以前的目標檢測器的做法),因此只有在代碼中才能更好的瞭解這些細節。學習的最好方法就是自己實現一個目標檢測器。在本系列中,筆者將從零開始實現單階段目標檢測器RetinaNet,包含數據集處理、數據增強、網絡結構、loss、decode等部分。

COCO數據集介紹

COCO數據集官方網站地址:http://cocodataset.org/#home 。COCO是一個大規模目標檢測數據集。COCO數據集每年都會更新,但是在目標檢測論文中我們只會用到COCO2014與COCO2017數據集。COCO2017數據集包括三個子集:train(118287張圖片)、val(5000張圖片)、test(40670張圖片),共有80個類。其中train和val集都提供了ground truth,test集沒有ground turth,需要把detect結果提交到COCO數據集官網上測試才能得到結果。

COCO2014與COCO2017數據集的區別?
在RetinaNet論文中提供的Detectron開源代碼(https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md)中我們可以找到相關解釋:RetinaNet論文中所有模型都是在coco_2014_train數據集(82783張圖片)和coco_2014_valminusminival數據集(共有40504張圖片)隨機劃分的含35504張圖片的子集的並集上進行訓練,這個並集實際上與coco_2017_train數據集完全一致。在測試時,所有模型都在coco_2014_minival數據集剩下的含有5000張圖片的另一個子集上進行測試,這個子集實際上與coco_2017_val數據集完全一致。也就是說,RetinaNet論文中實際上就是用coco_2017_train數據集訓練模型,用coco_2017_val數據集測試模型。
在RetinaNet論文中,模型的表現指IoU=0.5:0.95下,最多保留100個detect目標,保留所有大小的目標下的mAP(即pycocotools.cocoeval的COCOeval類中_summarizeDets函數中的stats[0]值)。

模型在val數據集上和test數據集上的表現差多少?
由於train、val、test集實際上都是從同一個母數據集隨機劃分成三部分得到的,模型在val集和test集中的表現差距很小。根據其他有同時在val和test上測試模型的論文中給出的結果,一般在val和test集上模型的mAP相差在0.2-0.3個百分點左右。

在接下來的復現中,我們遵循RetinaNet論文中的數據集設置,使用coco_2017_train數據集訓練模型,使用coco_2017_val數據集測試模型。使用IoU=0.5:0.95下,最多保留100個detect目標,保留所有大小的目標下的mAP(即pycocotools.cocoeval的COCOeval類中_summarizeDets函數中的stats[0]值)作爲模型的性能表現。

VOC數據集介紹

VOC數據集官方網站地址:http://host.robots.ox.ac.uk/pascal/VOC/ 。VOC也是一個目標檢測數據集,但規模要比COCO數據集小的多。在目標檢測論文中我們通常用VOC2007和VOC2012。和COCO數據集一樣,VOC2007和VOC2012都分爲train、val、test三個子集,共有20個類。對VOC2007,train、val、test三個子集都提供了ground truth。對VOC2012,只有train、val兩個子集提供了ground truth。

我們參照detectron2中使用faster rcnn在VOC數據集上訓練測試的做法(https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md),使用VOC2007trainval+VOC2012trainval數據集訓練模型,使用VOC2007test數據集測試模型。測試時使用VOC2007的11 point metric方式計算mAP。

COCO和VOC數據集文件組織結構

我們下載好COCO數據集和VOC數據集後,將文件夾組織結構調整成下面這樣:

COCO2017
|
|
|----annotations----contains all annotaion json files
|
|                  |----train2017
|----images--------|----val2017
                   |----test2017

VOCdataset
|
|
|                  |----Annotations
|                  |----ImageSets
|----VOC2007-------|----JPEGImages
|                  |----SegmentationClass
|                  |----SegmentationObject
|
|                  |----Annotations
|                  |----ImageSets
|----VOC2012-------|----JPEGImages
|                  |----SegmentationClass
|                  |----SegmentationObject

COCO數據集處理

COCO2017數據集標註中提供的原始box座標是[x_min,y_min,w,h],即框左上角座標和框的寬高,我們會將這個box座標轉換爲[x_min,y_min,x_max,y_max],即框左上角座標和框右下角座標。同時,標註中也提供了類別index,但是原始標註的類別index不連續(1-90,但是隻有80個類),我們要將其轉換成連續的類別index0-79。
處理COCO數據集的代碼如下:

import os
import cv2
import torch
import numpy as np
import random
from torch.utils.data import Dataset
from pycocotools.coco import COCO
import torch.nn.functional as F

COCO_CLASSES = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]

colors = [
    (39, 129, 113),
    (164, 80, 133),
    (83, 122, 114),
    (99, 81, 172),
    (95, 56, 104),
    (37, 84, 86),
    (14, 89, 122),
    (80, 7, 65),
    (10, 102, 25),
    (90, 185, 109),
    (106, 110, 132),
    (169, 158, 85),
    (188, 185, 26),
    (103, 1, 17),
    (82, 144, 81),
    (92, 7, 184),
    (49, 81, 155),
    (179, 177, 69),
    (93, 187, 158),
    (13, 39, 73),
    (12, 50, 60),
    (16, 179, 33),
    (112, 69, 165),
    (15, 139, 63),
    (33, 191, 159),
    (182, 173, 32),
    (34, 113, 133),
    (90, 135, 34),
    (53, 34, 86),
    (141, 35, 190),
    (6, 171, 8),
    (118, 76, 112),
    (89, 60, 55),
    (15, 54, 88),
    (112, 75, 181),
    (42, 147, 38),
    (138, 52, 63),
    (128, 65, 149),
    (106, 103, 24),
    (168, 33, 45),
    (28, 136, 135),
    (86, 91, 108),
    (52, 11, 76),
    (142, 6, 189),
    (57, 81, 168),
    (55, 19, 148),
    (182, 101, 89),
    (44, 65, 179),
    (1, 33, 26),
    (122, 164, 26),
    (70, 63, 134),
    (137, 106, 82),
    (120, 118, 52),
    (129, 74, 42),
    (182, 147, 112),
    (22, 157, 50),
    (56, 50, 20),
    (2, 22, 177),
    (156, 100, 106),
    (21, 35, 42),
    (13, 8, 121),
    (142, 92, 28),
    (45, 118, 33),
    (105, 118, 30),
    (7, 185, 124),
    (46, 34, 146),
    (105, 184, 169),
    (22, 18, 5),
    (147, 71, 73),
    (181, 64, 91),
    (31, 39, 184),
    (164, 179, 33),
    (96, 50, 18),
    (95, 15, 106),
    (113, 68, 54),
    (136, 116, 112),
    (119, 139, 130),
    (31, 139, 34),
    (66, 6, 127),
    (62, 39, 2),
    (49, 99, 180),
    (49, 119, 155),
    (153, 50, 183),
    (125, 38, 3),
    (129, 87, 143),
    (49, 87, 40),
    (128, 62, 120),
    (73, 85, 148),
    (28, 144, 118),
    (29, 9, 24),
    (175, 45, 108),
    (81, 175, 64),
    (178, 19, 157),
    (74, 188, 190),
    (18, 114, 2),
    (62, 128, 96),
    (21, 3, 150),
    (0, 6, 95),
    (2, 20, 184),
    (122, 37, 185),
]


class CocoDetection(Dataset):
    def __init__(self,
                 image_root_dir,
                 annotation_root_dir,
                 set='train2017',
                 transform=None):
        self.image_root_dir = image_root_dir
        self.annotation_root_dir = annotation_root_dir
        self.set_name = set
        self.transform = transform

        self.coco = COCO(
            os.path.join(self.annotation_root_dir,
                         'instances_' + self.set_name + '.json'))

        self.load_classes()

    def load_classes(self):
        self.image_ids = self.coco.getImgIds()
        self.cat_ids = self.coco.getCatIds()
        self.categories = self.coco.loadCats(self.cat_ids)
        self.categories.sort(key=lambda x: x['id'])

        # category_id is an original id,coco_id is set from 0 to 79
        self.category_id_to_coco_label = {
            category['id']: i
            for i, category in enumerate(self.categories)
        }
        self.coco_label_to_category_id = {
            v: k
            for k, v in self.category_id_to_coco_label.items()
        }

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img = self.load_image(idx)
        annot = self.load_annotations(idx)

        sample = {'img': img, 'annot': annot, 'scale': 1.}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def load_image(self, image_index):
        image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
        path = os.path.join(self.image_root_dir, image_info['file_name'])
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32) / 255.

    def load_annotations(self, image_index):
        # get ground truth annotations
        annotations_ids = self.coco.getAnnIds(
            imgIds=self.image_ids[image_index], iscrowd=None)
        annotations = np.zeros((0, 5))

        # some images appear to miss annotations
        if len(annotations_ids) == 0:
            return annotations

        # parse annotations
        coco_annotations = self.coco.loadAnns(annotations_ids)
        for _, a in enumerate(coco_annotations):
            # some annotations have basically no width / height, skip them
            if a['bbox'][2] < 1 or a['bbox'][3] < 1:
                continue

            annotation = np.zeros((1, 5))
            annotation[0, :4] = a['bbox']
            annotation[0, 4] = self.find_coco_label_from_category_id(
                a['category_id'])

            annotations = np.append(annotations, annotation, axis=0)

        # transform from [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max]
        annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
        annotations[:, 3] = annotations[:, 1] + annotations[:, 3]

        return annotations

    def find_coco_label_from_category_id(self, category_id):
        return self.category_id_to_coco_label[category_id]

    def find_category_id_from_coco_label(self, coco_label):
        return self.coco_label_to_category_id[coco_label]

    def num_classes(self):
        return 80

    def image_aspect_ratio(self, image_index):
        image = self.coco.loadImgs(self.image_ids[image_index])[0]
        return float(image['width']) / float(image['height'])

該類遍歷的每一個對象就是一張圖片的相關信息(在一個字典裏),鍵’img’對應的值就是圖片,鍵’annot’對應的numpy數組就是這張圖片標註的對象。注意每張圖片標註的對象數量不一定一樣,也有可能某張圖片沒有標註對象。

VOC數據集處理

VOC數據集標註中提供的原始box座標就是[x_min,y_min,x_max,y_max],因此不需要轉換座標。標註中只提供了類別的name,我們要將其映射爲類別index0-19。
處理VOC數據集的代碼如下:

import os
import cv2
import numpy as np
import random
import xml.etree.ElementTree as ET

import torch
from torch.utils.data import Dataset

VOC_CLASSES = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]

colors = [
    (39, 129, 113),
    (164, 80, 133),
    (83, 122, 114),
    (99, 81, 172),
    (95, 56, 104),
    (37, 84, 86),
    (14, 89, 122),
    (80, 7, 65),
    (10, 102, 25),
    (90, 185, 109),
    (106, 110, 132),
    (169, 158, 85),
    (188, 185, 26),
    (103, 1, 17),
    (82, 144, 81),
    (92, 7, 184),
    (49, 81, 155),
    (179, 177, 69),
    (93, 187, 158),
    (13, 39, 73),
]


class VocDetection(Dataset):
    def __init__(self,
                 root_dir,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 transform=None,
                 keep_difficult=False):
        self.root_dir = root_dir
        self.image_set = image_sets
        self.transform = transform
        self.categories = VOC_CLASSES

        self.category_id_to_voc_label = dict(
            zip(self.categories, range(len(self.categories))))
        self.voc_label_to_category_id = {
            v: k
            for k, v in self.category_id_to_voc_label.items()
        }

        self.keep_difficult = keep_difficult

        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = os.path.join(self.root_dir, 'VOC' + year)
            for line in open(
                    os.path.join(rootpath, 'ImageSets', 'Main',
                                 name + '.txt')):
                self.ids.append((rootpath, line.strip()))

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img = self.load_image(img_id)

        target = ET.parse(self._annopath % img_id).getroot()
        annot = self.load_annotations(target)

        sample = {'img': img, 'annot': annot, 'scale': 1.}

        if self.transform:
            sample = self.transform(sample)
        return sample

    def load_image(self, img_id):
        img = cv2.imread(self._imgpath % img_id)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img.astype(np.float32) / 255.

    def load_annotations(self, target):
        annotations = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']

            bndbox = []
            for pt in pts:
                cur_pt = float(bbox.find(pt).text)
                bndbox.append(cur_pt)
            label_idx = self.category_id_to_voc_label[name]
            bndbox.append(label_idx)
            annotations += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]

        annotations = np.array(annotations)
        # format:[[x1, y1, x2, y2, label_ind], ... ]
        return annotations

    def find_category_id_from_voc_label(self, voc_label):
        return self.voc_label_to_category_id[voc_label]

    def image_aspect_ratio(self, idx):
        img_id = self.ids[idx]
        image = self.load_image(img_id)
        #w/h
        return float(image.shape[1]) / float(image.shape[0])

    def __len__(self):
        return len(self.ids)

和COCO類類似,該類遍歷的每一個對象就是一張圖片的相關信息(在一個字典裏),鍵’img’對應的值就是圖片,鍵’annot’對應的numpy數組就是這張圖片標註的對象。注意每張圖片標註的對象數量不一定一樣,也有可能某張圖片沒有標註對象。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章