mmdetection源碼筆記(三):創建數據集模型之datasets/coco.py的解讀(上)

引言

同樣,數據集也是需要build模型的。類CocoDataset是CustomDataset類的子類,而CustomDataset是Dataset的子類。(之前的創建模型,都是torch.nn.module的子類,數據集的創建就不是咯,注意一下)
關於CustomDataset的定義和其類方法的代碼解讀,可以看下面這篇文章:

類CocoDataset依然作爲形參,添加到@DATASETS.register_module中,作用就是將其保存到註冊表的module中。共有四個方法:

  • load_annotations():加載標註文件中的annotation字典,返回圖片信息,比如:info{"filename":"284193,faa9000f2678b5e.jpg"}
  • get_ann_info(self,idx):獲得annotation的信息,其實是調用了_parse_ann_info();它的形參是指定的圖片id,返回值是個字典:bboxes,bboxes_ignore, labels, masks, mask_polys, poly_lens.
  • _filter_imgs(self, min_size):過濾圖片,去除沒有annotation標註文件的圖片,以及圖片尺寸小於min-size的圖片。
  • _parse_ann_info(self, ann_info, with_mask=True):解析一張圖片的annotation的信息,主要是bbox和mask信息,返回值爲:bboxes,bboxes_ignore, labels, masks, mask_polys, poly_lens.(如果沒用mask分支,就沒用後面的三個返回值了)

在父類中custom.py有其初始化,還有另外的幾個重要的函數,比如prepare_train_img()、prepare_test_img()等。因爲其代碼行數太長,所以在上面的鏈接裏寫一篇講解。
以下是coco.py的代碼,如有錯誤的地方,還請指出,後面博主也會繼續修改,增加對各個代碼段的理解。

coco.py代碼註釋

import numpy as np
from pycocotools.coco import COCO
from .custom import CustomDataset
from .registry import DATASETS
@DATASETS.register_module
class CocoDataset(CustomDataset):
""" coco api
self.coco = COCO(ann_file)
create class members
		# 一維數組,值爲對應原coco數據集的annotation/images/categories信息
        coco.anns = anns  
        coco.imgs = imgs
        coco.cats = cats
        # 兩個默認value爲list的字典 比如:imgToAnns{"1":[ann1,ann2,ann3,....]},'1'爲image_id = 1,ann1時其爲 1 的annotation。
        coco.imgToAnns = imgToAnns
        coco.catToImgs = catToImgs
        以上五個members,在實例化COCO時,被創建
"""
    CLASSES = ('person', 'mask')
    def load_annotations(self, ann_file):
        self.coco = COCO(ann_file)    
        
        #函數 getImgIds()、getCatIds()、getAnnIds(),返回值爲integer array of img/cat/ann ids,形參爲過濾條件
        self.cat_ids = self.coco.getCatIds()  # [1,2] , integer array of cat ids
        self.cat2label = {    #  dict
            cat_id: i + 1     #  cat_id : cat_id + 1 ??
            for i, cat_id in enumerate(self.cat_ids) # enumerate 遍歷 , i 從0 開始,是取其索引的意思。cat_id就是索引下的值。
        }
        self.img_ids = self.coco.getImgIds()
        img_infos = []
        for i in self.img_ids:
            info = self.coco.loadImgs([i])[0]   
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos                        
        # 一個例子:比如info{'file_name': '273278,e118d000ec53d5cd.jpg', 'height': 1365, 'width': 2048, 'id': 4370, 'filename': '273278,e118d000ec53d5cd.jpg'} 

    # 指定id,獲得該圖片的標註信息。{bboxes, bboxes_ignore,labels, masks, mask_polys, poly_lens}
    def get_ann_info(self, idx):          
        img_id = self.img_infos[idx]['id']
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])          
        ann_info = self.coco.loadAnns(ann_ids)                  # loadAnns()  return ann or [ann,ann,ann],這裏時返回單個ann數據
        return self._parse_ann_info(ann_info, self.with_mask)   # 調用 _parse_ann_info(), return dict {bboxes, bboxes_ignore,labels, masks, mask_polys, poly_lens}

    def _filter_imgs(self, min_size=32):
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())  #獲得ann標註文件裏的image_id的值
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:     #去除無標註文件的圖片
                continue
            if min(img_info['width'], img_info['height']) >= min_size: # 去除小圖片,尺寸小於32的圖片不要。
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, ann_info, with_mask=True):  
        """Parse bbox and mask annotation.
        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.
        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        # list of float.
        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []
        for i, ann in enumerate(ann_info):  # ann_info (list[dict]): Annotation info of an image.
            if ann.get('ignore', False):  ##  ignore  ??? coco haven't 'ignore'
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]  ## 變換?
            if ann['iscrowd']:
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])# 也就是把iscrowd=0的保存下來,他們之間的對應:bbox ,self.cat2label[ann['category_id']]?
            if with_mask:
                gt_masks.append(self.coco.annToMask(ann))
                mask_polys = [
                    p for p in ann['segmentation'] if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)
        # deal with gt_bboxes
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)# 轉換成numpy中array的格式,好切片處理
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)
        # deal with gt_bboxes_ignore
        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)  
        ann = dict(
            bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)   # bboxes_ignore iscrowd 字段的作用不是僅用於 segmentation的嗎?
        if with_mask:    # mask 處理
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = gt_poly_lens
        return ann

mmdetection 系列推薦文章:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章