mmdetection訓練VOC數據

訓練自己數據 VOC格式

目錄

1、加入類別name定義

2、定義dataset

3、修改config.py


1、加入類別name定義

# mmdetection\mmdet\core\evaluation\class_names.py

def voc_trash_classes():
    return [
        '一次性快餐盒', '書籍紙張', '充電寶', '剩飯剩菜', '包', '垃圾桶',
        '塑料器皿', '塑料玩具', '塑料衣架', '大骨頭', '乾電池', '快遞紙袋', '插頭電線',
        '舊衣服', '易拉罐', '枕頭', '果皮果肉', '毛絨玩具', '污損塑料', '污損用紙', '洗護用品',
        '菸蒂', '牙籤', '玻璃器皿', '砧板', '筷子', '紙盒紙箱', '花盆', '茶葉渣', '菜幫菜葉',
        '蛋殼', '調料瓶', '軟膏', '過期藥物', '酒瓶', '金屬廚具', '金屬器皿', '金屬食品罐', '鍋',
        '陶瓷器皿', '鞋', '食用油桶', '飲料瓶', '魚骨'
    ]

dataset_aliases = {
    'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
    'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
    'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
    'coco': ['coco', 'mscoco', 'ms_coco'],
    'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace'],
    'cityscapes': ['cityscapes'],
    'voc_trash':['voc07_trash'] # 添加
}

2、定義dataset

# mmdetection\mmdet\datasets\voc_trash.py
# 根據 voc.py 修改

@DATASETS.register_module()
class VOCDataset_trash(XMLDataset):

    CLASSES =('一次性快餐盒','書籍紙張','充電寶','剩飯剩菜','包','垃圾桶',
              '塑料器皿','塑料玩具','塑料衣架','大骨頭','乾電池','快遞紙袋','插頭電線',
              '舊衣服','易拉罐','枕頭','果皮果肉','毛絨玩具','污損塑料','污損用紙','洗護用品',
              '菸蒂','牙籤','玻璃器皿','砧板','筷子','紙盒紙箱','花盆','茶葉渣','菜幫菜葉',
              '蛋殼','調料瓶','軟膏','過期藥物','酒瓶','金屬廚具','金屬器皿','金屬食品罐','鍋',
              '陶瓷器皿','鞋','食用油桶','飲料瓶','魚骨')
# ...
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            if self.year == 2007:
                ds_name = 'voc07_trash' # 上一步dataset_aliases中添加的
            else:
                ds_name = self.dataset.CLASSES
# mmdetection\mmdet\datasets\__init__.py
# ...
__all__ = [
    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
    'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
    'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
    'build_dataset',
    'VOCDataset_trash'  # 添加類name
]

3、修改config.py

_base_ = [  # path might need to be changed
    './_base_/models/faster_rcnn_r50_fpn.py',
    './_base_/datasets/voc0712.py',
    './_base_/schedules/schedule_1x.py',
    './_base_/default_runtime.py'
]

model = dict(
    type='FasterRCNN',
    pretrained=None,
    # pretrained='torchvision://resnet50',
    roi_head=dict(
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            # num_classes=80, # for coco
            num_classes=44,   # 修改類別數,不含背景
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)
        )
    )
)

# dataset settings
dataset_type = 'VOCDataset_trash' # 自定義dataset類型
data_root = 'data/trainval/' # data_root = 'data/VOCdevkit/'

# load pretrained model with only weights
load_from = './checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章