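How to use a custom dataset with maskrcnn-benchmark (PyTorch)

Preface

Reference code: maskrcnn-benchmark

Dataset: Jinnan Digital Manufacturing Algorithm Challenge [Track 2], preliminary round

This code cannot be run as-is and is provided for reference only. I have been working on object detection for less than a week, so if you have any questions, feel free to discuss them in the comments.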

1. Understanding the data

The training annotation file train_no_poly.json uses a COCO-style format:

import json
with open('../train_no_poly.json', 'r') as f:
    data = json.load(f)

print(data.keys())
>>> dict_keys(['info', 'licenses', 'categories', 'images', 'annotations'])

print(data['info'])
>>> {'description': 'XRAY Instance Dataset ', 'url': '', 'version': '0.2.0', 'year': 2019, 'contributor': 'qianxiao', 'date_created': '2019-03-04 08:52:50.852455'}

print(data['licenses'])
>>> [{'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License', 'url': ''}]

print(data['categories'])
>>> [{'id': 1, 'name': '鐵殼打火機', 'supercategory': 'restricted_obj'}, {'id': 2, 'name': '黑釘打火機', 'supercategory': 'restricted_obj'}, {'id': 3, 'name': '刀具', 'supercategory': 'restricted_obj'}, {'id': 4, 'name': '電源和電池', 'supercategory': 'restricted_obj'}, {'id': 5, 'name': '剪刀', 'supercategory': 'restricted_obj'}]
# i.e. iron-shell lighter, black-nail lighter, knife, power supply and battery, scissors

print(data['images'][0])
>>> {'coco_url': '', 'data_captured': '', 'file_name': '190119_184244_00166940.jpg', 'flickr_url': '', 'id': 0, 'height': 391, 'width': 680, 'license': 1}

print(data['annotations'][0])  # Note: one image can have several bboxes; the JSON stores each bbox in its own dict. bbox is [x, y, w, h] (COCO convention)
>>> {'id': 1, 'image_id': 0, 'category_id': 3, 'iscrowd': 0, 'segmentation': [], 'area': [], 'bbox': [88, 253, 118, 42], 'minAreaRect': [[88, 298], [86, 256], [203, 249], [206, 291]]}
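Because every box is a separate entry in annotations, it helps to group them by image_id before writing the dataset class, and to find images that have no boxes at all (the loader in section 5 has to skip those). A minimal sketch using only the standard library; the function name summarize_annotations is hypothetical and it expects the data dict loaded above:

```python
from collections import Counter, defaultdict


def summarize_annotations(data):
    """Group COCO-style annotations by image_id and report images without any box.

    Returns (boxes_per_image, images_without_boxes, boxes_per_category).
    """
    boxes_per_image = defaultdict(list)
    for anno in data["annotations"]:
        boxes_per_image[anno["image_id"]].append(anno["bbox"])

    # image ids that never appear in any annotation
    no_boxes = [img["id"] for img in data["images"] if img["id"] not in boxes_per_image]

    # how many boxes each category_id has
    per_category = Counter(anno["category_id"] for anno in data["annotations"])
    return boxes_per_image, no_boxes, per_category
```

Usage: boxes_per_image, no_boxes, per_category = summarize_annotations(data), then print no_boxes to see which images the loader must skip.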

2. Copy the dataset into the datasets folder in the repository root (at the same level as demo), e.g.

maskrcnn-benchmark/datasets/jinnan/jinnan2_round1_train_20190305

3. Modify paths_catalog.py

The file is maskrcnn-benchmark/maskrcnn_benchmark/config/paths_catalog.py

a. Add the paths you need to the DATASETS dictionary in paths_catalog, e.g.

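"jinnan_train": {
    # both paths are relative to DatasetCatalog.DATA_DIR ("datasets"), hence the jinnan/ prefix
    "img_dir": "jinnan/jinnan2_round1_train_20190305",  # folder containing the RGB images
    "ann_file": "jinnan/jinnan2_round1_train_20190305/train_no_poly.json"
},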

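Note: for a custom dataset, img_dir and ann_file are joined with DATA_DIR in get() and then passed as arguments (root and ann_file) to the MyDataset class you create.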

b. Modify the static method get(name) in paths_catalog

Add an elif branch that handles your dataset, e.g.

elif "jinnan" in name:  # name is the dataset name passed in from the yaml file
    data_dir = DatasetCatalog.DATA_DIR
    attrs = DatasetCatalog.DATASETS[name]
    args = dict(
        root=os.path.join(data_dir, attrs["img_dir"]),  # img_dir from step a
        ann_file=os.path.join(data_dir, attrs["ann_file"]),  # ann_file from step a
    )
    return dict(
        factory="MyDataset",  # name of your own dataset class (explained below)
        args=args,
    )
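To see what get() hands back without running the whole framework, here is a simplified, self-contained stand-in for DatasetCatalog with only the jinnan branch. It is a sketch, not the real class: in the versions of paths_catalog.py I have seen, get() ends with a raise RuntimeError for unknown names, so the new elif must sit before that line. It also shows that a name containing "jinnan" but missing from DATASETS fails with a KeyError.

```python
import os


class DatasetCatalog(object):
    # simplified stand-in for maskrcnn_benchmark/config/paths_catalog.py
    DATA_DIR = "datasets"
    DATASETS = {
        "jinnan_train": {
            "img_dir": "jinnan/jinnan2_round1_train_20190305",
            "ann_file": "jinnan/jinnan2_round1_train_20190305/train_no_poly.json",
        },
    }

    @staticmethod
    def get(name):
        if "jinnan" in name:
            data_dir = DatasetCatalog.DATA_DIR
            attrs = DatasetCatalog.DATASETS[name]  # KeyError if the name was never registered
            args = dict(
                root=os.path.join(data_dir, attrs["img_dir"]),
                ann_file=os.path.join(data_dir, attrs["ann_file"]),
            )
            return dict(factory="MyDataset", args=args)
        raise RuntimeError("Dataset not available: {}".format(name))
```

Usage: DatasetCatalog.get("jinnan_train")["args"]["root"] gives datasets/jinnan/jinnan2_round1_train_20190305 on Linux, assuming the data was copied as in step 2.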

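Explanation of the above (mainly MyDataset):

  1. MyDataset is the class you write yourself. Its __getitem__ returns image, boxlist, idx; for how to implement it, see the custom-dataset example in the official GitHub README (it is simple).

  2. For example, once MyDataset was implemented, I saved it in a file named jinnan.py

  3. and put that file in maskrcnn-benchmark/maskrcnn_benchmark/data/datasets.

  4. Then edit the __init__.py file in that directory; the fourth line and the last element of __all__ are my additions: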

from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .jinnan import MyDataset

__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "MyDataset"]
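  5. Note: MyDataset must implement __init__, __len__, __getitem__ and get_img_info. __init__ receives the arguments built from attrs in step 3 (root and ann_file), and build.py also passes transforms. A possible __init__ signature:
def __init__(self, ann_file=None, root=None, remove_images_without_annotations=None, transforms=None)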

If you are unsure what a parameter means, see maskrcnn-benchmark/maskrcnn_benchmark/data/build.py.
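The part of build.py that matters here can be sketched in isolation: it looks the factory string up with getattr on the datasets package, which is why MyDataset has to be importable from that package's __init__.py, and it adds transforms to the args. In the version I looked at, remove_images_without_annotations is only filled in for COCODataset, so MyDataset gets root, ann_file and transforms. Below, D, MyDataset and build_dataset_sketch are hypothetical stand-ins, not the library's code.

```python
from types import SimpleNamespace


class MyDataset(object):
    # hypothetical stand-in that only records the arguments it receives
    def __init__(self, ann_file=None, root=None, remove_images_without_annotations=None, transforms=None):
        self.ann_file = ann_file
        self.root = root
        self.transforms = transforms


# stand-in for the maskrcnn_benchmark.data.datasets package (imported as D in build.py)
D = SimpleNamespace(MyDataset=MyDataset)


def build_dataset_sketch(catalog_entry, transforms):
    """Simplified version of what data/build.py does for one dataset name."""
    factory = getattr(D, catalog_entry["factory"])  # look the class up by its name
    args = dict(catalog_entry["args"])              # root and ann_file from get(name)
    args["transforms"] = transforms                 # transforms are always added
    return factory(**args)
```

If the class is not exported from the package, the getattr call is what fails, with an AttributeError naming MyDataset.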

4. Modify the yaml file

Mainly the data-loading part:

MODEL:
  MASK_ON: False
DATASETS:
  TRAIN: ("jinnan_train", "jinnan_val")
  TEST: ("jinnan_test",)

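All three names above are my own choices; the one that really matters is jinnan_train. Every name listed under TRAIN and TEST must also have an entry in the DATASETS dictionary from step 3a, otherwise get() raises a KeyError when that dataset is built. Most importantly, MASK_ON must be set to False, since this dataset has no segmentation masks.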
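Because a missing entry only shows up once the data loader is built, a small check can catch it earlier. A sketch, assuming you copy the TRAIN/TEST tuples from the yaml and the keys of DATASETS by hand; missing_dataset_names is a hypothetical helper.

```python
def missing_dataset_names(dataset_names, datasets):
    """Return the names from the yaml TRAIN/TEST tuples that have no DATASETS entry."""
    return [name for name in dataset_names if name not in datasets]


# only jinnan_train was registered in paths_catalog.py in step 3a
DATASETS = {"jinnan_train": {"img_dir": "...", "ann_file": "..."}}
TRAIN = ("jinnan_train", "jinnan_val")
TEST = ("jinnan_test",)

print(missing_dataset_names(TRAIN + TEST, DATASETS))  # prints ['jinnan_val', 'jinnan_test']
```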

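5. My own (somewhat messy) data-loading code, for reference

This is only a reference; whatever your needs, the dataset just has to return image, boxlist, idx.

Location: maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/jinnan.py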

from maskrcnn_benchmark.structures.bounding_box import BoxList
from PIL import Image
import os
import json
import torch

class MyDataset(object):
    def __init__(self, ann_file=None, root=None, remove_images_without_annotations=None, transforms=None):
        self.transforms = transforms
        self.train_path = root
        with open(ann_file, 'r') as f:
            self.data = json.load(f)

        # group the annotations by image_id: {image_id: [[bbox, ...], [category_id, ...]]}
        self.bbox_label = {}
        for anno in self.data['annotations']:
            # COCO stores boxes as [x, y, w, h]; convert to [x1, y1, x2, y2]
            x, y, w, h = anno['bbox']
            bbox = [x, y, x + w, y + h]
            cate = anno['category_id']
            image_id = anno['image_id']
            if image_id not in self.bbox_label:
                self.bbox_label[image_id] = [[bbox], [cate]]
            else:
                self.bbox_label[image_id][0].append(bbox)
                self.bbox_label[image_id][1].append(cate)

        # positions in data['images'] of the images that have at least one box;
        # images without boxes (e.g. 210, 262, 690, 855 here) would otherwise raise a KeyError
        self.idxs = [i for i, img in enumerate(self.data['images'])
                     if img['id'] in self.bbox_label]

    def __getitem__(self, idx):
        # map the dataset index to a position in data['images']
        img_idx = self.idxs[idx]
        img_info = self.data['images'][img_idx]
        path = img_info['file_name']

        # dataset-specific: the first 981 images are in restricted/, the rest in normal/
        folder = 'restricted' if img_idx < 981 else 'normal'

        # load the image as a PIL Image
        image = Image.open(os.path.join(self.train_path, folder, path)).convert('RGB')

        # boxes in x1, y1, x2, y2 order, e.g. [[0, 0, 10, 10], [10, 20, 50, 50]], and their labels
        boxes, category = self.bbox_label[img_info['id']]
        labels = torch.tensor(category)

        # create a BoxList from the boxes and add the labels to it
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in your dataset
        return image, boxlist, idx

    def __len__(self):
        return len(self.idxs)

    def get_img_info(self, idx):
        # get img_height and img_width. This is used if
        # we want to split the batches according to the aspect ratio
        # of the image, as it can be more efficient than loading the
        # image from disk
        img_info = self.data['images'][self.idxs[idx]]
        return {"height": img_info['height'], "width": img_info['width']}
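The box conversion in the loader is easy to check on its own, without torch or the framework. A sketch that applies the same [x, y, w, h] to [x1, y1, x2, y2] conversion and lists annotations whose converted box has zero or negative width or height, which are worth looking at before a long training run; to_xyxy and degenerate_annotations are hypothetical helper names.

```python
def to_xyxy(bbox):
    """Convert a COCO [x, y, w, h] box to [x1, y1, x2, y2], the same way the loader does."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def degenerate_annotations(annotations):
    """Return the ids of annotations whose converted box has no positive width or height."""
    bad = []
    for anno in annotations:
        x1, y1, x2, y2 = to_xyxy(anno["bbox"])
        if x2 <= x1 or y2 <= y1:
            bad.append(anno["id"])
    return bad


# the first annotation from section 1
print(to_xyxy([88, 253, 118, 42]))  # prints [88, 253, 206, 295]
```

Usage: degenerate_annotations(data['annotations']) with data loaded as in section 1.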

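Other notes

Transforms: they are created in make_data_loader in maskrcnn-benchmark/maskrcnn_benchmark/data/build.py and passed to the dataset as the transforms argument.

In the yaml file, set MODEL.WEIGHT to the path of your own weights, including the file name.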

Change _C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81 (in maskrcnn_benchmark/config/defaults.py) to 6, i.e. 5 categories plus background.
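Since the loader uses category_id directly as the label and label 0 is reserved for background, NUM_CLASSES has to be the largest category id plus one. A sketch that derives it from the categories list; num_classes_from_categories is a hypothetical helper.

```python
def num_classes_from_categories(categories):
    """NUM_CLASSES counts the background class (label 0) too, so it is max category id + 1."""
    return max(cat["id"] for cat in categories) + 1


categories = [{"id": i} for i in range(1, 6)]  # the five Jinnan categories, ids 1..5
print(num_classes_from_categories(categories))  # prints 6
```

With contiguous ids starting at 1 this equals len(categories) + 1; using the maximum id also stays correct if the ids have gaps.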

Here I had set category_id to image_id. Labels outside 1 .. NUM_CLASSES - 1 can cause the following CUDA error:

copy_if failed to synchronize: device-side assert triggered

https://github.com/facebookresearch/maskrcnn-benchmark/issues/450
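A device-side assert is a generic CUDA error, but one common cause in this setup is a label that does not fit the classifier, for example an image id used as a label. A sketch that lists such annotations before training; out_of_range_labels is a hypothetical helper and num_classes should match MODEL.ROI_BOX_HEAD.NUM_CLASSES.

```python
def out_of_range_labels(annotations, num_classes):
    """Return (annotation id, category_id) pairs that are not valid foreground labels.

    Label 0 is reserved for background, so valid labels are 1 .. num_classes - 1.
    """
    return [
        (anno["id"], anno["category_id"])
        for anno in annotations
        if not 1 <= anno["category_id"] < num_classes
    ]
```

Usage: out_of_range_labels(data['annotations'], 6) should return an empty list for this dataset.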
