引言
同样,数据集也是需要build模型的。类CocoDataset是CustomDataset类的子类,而CustomDataset是Dataset
的子类。(之前的创建模型,都是torch.nn.module的子类,数据集的创建就不是咯,注意一下)
关于CustomDataset的定义和其类方法的代码解读,可以看下面这篇文章:
类CocoDataset依然作为形参,添加到@DATASETS.register_module中,作用就是将其保存到注册表的module中。共有四个方法:
load_annotations()
:加载标注文件中的annotation字典,返回图片信息,比如:info{"filename":"284193,faa9000f2678b5e.jpg"}
。get_ann_info(self,idx)
:获得annotation的信息,其实是调用了_parse_ann_info()
;它的形参是指定的图片id,返回值是个字典:bboxes
,bboxes_ignore
,labels
,masks
,mask_polys
,poly_lens
._filter_imgs(self, min_size)
:过滤图片,去除没有annotation标注文件的图片,以及图片尺寸小于min-size的图片。_parse_ann_info(self, ann_info, with_mask=True)
:解析一张图片的annotation的信息,主要是bbox和mask信息,返回值为:bboxes
,bboxes_ignore
,labels
,masks
,mask_polys
,poly_lens
.(如果没用mask分支,就没用后面的三个返回值了)
在父类中custom.py有其初始化,还有另外的几个重要的函数,比如prepare_train_img()、prepare_test_img()等。因为其代码行数太长,所以在上面的链接里写一篇讲解。
以下是coco.py的代码,如有错误的地方,还请指出,后面博主也会继续修改,增加对各个代码段的理解。
coco.py代码注释
import numpy as np
from pycocotools.coco import COCO
from .custom import CustomDataset
from .registry import DATASETS
@DATASETS.register_module
class CocoDataset(CustomDataset):
""" coco api
self.coco = COCO(ann_file)
create class members
# 一维数组,值为对应原coco数据集的annotation/images/categories信息
coco.anns = anns
coco.imgs = imgs
coco.cats = cats
# 两个默认value为list的字典 比如:imgToAnns{"1":[ann1,ann2,ann3,....]},'1'为image_id = 1,ann1时其为 1 的annotation。
coco.imgToAnns = imgToAnns
coco.catToImgs = catToImgs
以上五个members,在实例化COCO时,被创建
"""
CLASSES = ('person', 'mask')
def load_annotations(self, ann_file):
self.coco = COCO(ann_file)
#函数 getImgIds()、getCatIds()、getAnnIds(),返回值为integer array of img/cat/ann ids,形参为过滤条件
self.cat_ids = self.coco.getCatIds() # [1,2] , integer array of cat ids
self.cat2label = { # dict
cat_id: i + 1 # cat_id : cat_id + 1 ??
for i, cat_id in enumerate(self.cat_ids) # enumerate 遍历 , i 从0 开始,是取其索引的意思。cat_id就是索引下的值。
}
self.img_ids = self.coco.getImgIds()
img_infos = []
for i in self.img_ids:
info = self.coco.loadImgs([i])[0]
info['filename'] = info['file_name']
img_infos.append(info)
return img_infos
# 一个例子:比如info{'file_name': '273278,e118d000ec53d5cd.jpg', 'height': 1365, 'width': 2048, 'id': 4370, 'filename': '273278,e118d000ec53d5cd.jpg'}
# 指定id,获得该图片的标注信息。{bboxes, bboxes_ignore,labels, masks, mask_polys, poly_lens}
def get_ann_info(self, idx):
img_id = self.img_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids) # loadAnns() return ann or [ann,ann,ann],这里时返回单个ann数据
return self._parse_ann_info(ann_info, self.with_mask) # 调用 _parse_ann_info(), return dict {bboxes, bboxes_ignore,labels, masks, mask_polys, poly_lens}
def _filter_imgs(self, min_size=32):
valid_inds = []
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) #获得ann标注文件里的image_id的值
for i, img_info in enumerate(self.img_infos):
if self.img_ids[i] not in ids_with_ann: #去除无标注文件的图片
continue
if min(img_info['width'], img_info['height']) >= min_size: # 去除小图片,尺寸小于32的图片不要。
valid_inds.append(i)
return valid_inds
def _parse_ann_info(self, ann_info, with_mask=True):
"""Parse bbox and mask annotation.
Args:
ann_info (list[dict]): Annotation info of an image.
with_mask (bool): Whether to parse mask annotations.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,
labels, masks, mask_polys, poly_lens.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
# Two formats are provided.
# 1. mask: a binary map of the same size of the image.
# 2. polys: each mask consists of one or several polys, each poly is a
# list of float.
if with_mask:
gt_masks = []
gt_mask_polys = []
gt_poly_lens = []
for i, ann in enumerate(ann_info): # ann_info (list[dict]): Annotation info of an image.
if ann.get('ignore', False): ## ignore ??? coco haven't 'ignore'
continue
x1, y1, w, h = ann['bbox']
if ann['area'] <= 0 or w < 1 or h < 1:
continue
bbox = [x1, y1, x1 + w - 1, y1 + h - 1] ## 变换?
if ann['iscrowd']:
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])# 也就是把iscrowd=0的保存下来,他们之间的对应:bbox ,self.cat2label[ann['category_id']]?
if with_mask:
gt_masks.append(self.coco.annToMask(ann))
mask_polys = [
p for p in ann['segmentation'] if len(p) >= 6
] # valid polygons have >= 3 points (6 coordinates)
poly_lens = [len(p) for p in mask_polys]
gt_mask_polys.append(mask_polys)
gt_poly_lens.extend(poly_lens)
# deal with gt_bboxes
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)# 转换成numpy中array的格式,好切片处理
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
# deal with gt_bboxes_ignore
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
ann = dict(
bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore) # bboxes_ignore iscrowd 字段的作用不是仅用于 segmentation的吗?
if with_mask: # mask 处理
ann['masks'] = gt_masks
# poly format is not used in the current implementation
ann['mask_polys'] = gt_mask_polys
ann['poly_lens'] = gt_poly_lens
return ann
mmdetection 系列推荐文章: