mmdetection训练VOC数据

训练自己数据 VOC格式

目录

1、加入类别name定义

2、定义dataset

3、修改config.py


1、加入类别name定义

# mmdetection\mmdet\core\evaluation\class_names.py

def voc_trash_classes():
    return [
        '一次性快餐盒', '书籍纸张', '充电宝', '剩饭剩菜', '包', '垃圾桶',
        '塑料器皿', '塑料玩具', '塑料衣架', '大骨头', '干电池', '快递纸袋', '插头电线',
        '旧衣服', '易拉罐', '枕头', '果皮果肉', '毛绒玩具', '污损塑料', '污损用纸', '洗护用品',
        '烟蒂', '牙签', '玻璃器皿', '砧板', '筷子', '纸盒纸箱', '花盆', '茶叶渣', '菜帮菜叶',
        '蛋壳', '调料瓶', '软膏', '过期药物', '酒瓶', '金属厨具', '金属器皿', '金属食品罐', '锅',
        '陶瓷器皿', '鞋', '食用油桶', '饮料瓶', '鱼骨'
    ]

dataset_aliases = {
    'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
    'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
    'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
    'coco': ['coco', 'mscoco', 'ms_coco'],
    'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace'],
    'cityscapes': ['cityscapes'],
    'voc_trash':['voc07_trash'] # 添加
}

2、定义dataset

# mmdetection\mmdet\datasets\voc_trash.py
# 根据 voc.py 修改

@DATASETS.register_module()
class VOCDataset_trash(XMLDataset):

    CLASSES =('一次性快餐盒','书籍纸张','充电宝','剩饭剩菜','包','垃圾桶',
              '塑料器皿','塑料玩具','塑料衣架','大骨头','干电池','快递纸袋','插头电线',
              '旧衣服','易拉罐','枕头','果皮果肉','毛绒玩具','污损塑料','污损用纸','洗护用品',
              '烟蒂','牙签','玻璃器皿','砧板','筷子','纸盒纸箱','花盆','茶叶渣','菜帮菜叶',
              '蛋壳','调料瓶','软膏','过期药物','酒瓶','金属厨具','金属器皿','金属食品罐','锅',
              '陶瓷器皿','鞋','食用油桶','饮料瓶','鱼骨')
# ...
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            if self.year == 2007:
                ds_name = 'voc07_trash' # 上一步dataset_aliases中添加的
            else:
                ds_name = self.dataset.CLASSES
# mmdetection\mmdet\datasets\__init__.py
# ...
__all__ = [
    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
    'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
    'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
    'build_dataset',
    'VOCDataset_trash'  # 添加类name
]

3、修改config.py

_base_ = [  # path might need to be changed
    './_base_/models/faster_rcnn_r50_fpn.py',
    './_base_/datasets/voc0712.py',
    './_base_/schedules/schedule_1x.py',
    './_base_/default_runtime.py'
]

model = dict(
    type='FasterRCNN',
    pretrained=None,
    # pretrained='torchvision://resnet50',
    roi_head=dict(
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            # num_classes=80, # for coco
            num_classes=44,   # 修改类别数,不含背景
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)
        )
    )
)

# dataset settings
dataset_type = 'VOCDataset_trash' # 自定义dataset类型
data_root = 'data/trainval/' # data_root = 'data/VOCdevkit/'

# load pretrained model with only weights
load_from = './checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章