有時候我們會抽取一些公開數據集中某些類別的數據,作爲自己的補充訓練數據。抽取VOC2012數據集指定類別的方法之前已經講過,參考:Yolov3 行人檢測 – 使用Yolov3訓練從VOC2012抽取出來的行人數據
本文介紹如何抽取COCO數據集的指定類別,並將標籤保存爲XML格式。代碼參考自網上,不過對代碼進行了整理和註釋,直接看代碼!
from pycocotools.coco import COCO
import os
import shutil
from tqdm import tqdm
import skimage.io as io
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw
# Opening part of a VOC-style annotation file. The %-placeholders are filled,
# in order, with: file name (%s), image width, image height, channel count (%d).
headstr = (
    "<annotation>\n"
    "<folder>VOC</folder>\n"
    "<filename>%s</filename>\n"
    "<source>\n"
    "<database>My Database</database>\n"
    "<annotation>COCO</annotation>\n"
    "<image>flickr</image>\n"
    "<flickrid>NULL</flickrid>\n"
    "</source>\n"
    "<owner>\n"
    "<flickrid>NULL</flickrid>\n"
    "<name>company</name>\n"
    "</owner>\n"
    "<size>\n"
    "<width>%d</width>\n"
    "<height>%d</height>\n"
    "<depth>%d</depth>\n"
    "</size>\n"
    "<segmented>0</segmented>\n"
)

# One <object> element. Placeholders, in order: class name (%s), then
# xmin, ymin, xmax, ymax (%d each).
objstr = (
    "<object>\n"
    "<name>%s</name>\n"
    "<pose>Unspecified</pose>\n"
    "<truncated>0</truncated>\n"
    "<difficult>0</difficult>\n"
    "<bndbox>\n"
    "<xmin>%d</xmin>\n"
    "<ymin>%d</ymin>\n"
    "<xmax>%d</xmax>\n"
    "<ymax>%d</ymax>\n"
    "</bndbox>\n"
    "</object>\n"
)

# Closing tag that terminates the annotation document.
tailstr = "</annotation>\n"
def mkr(path):
    """Create an empty directory at ``path``, replacing any existing one.

    If ``path`` already exists it is removed together with everything
    beneath it, then recreated. Missing parent directories are created as
    well, so a nested target such as ``out/images/`` works even when
    ``out`` does not exist yet.

    Args:
        path: Directory to (re)create.
    """
    if os.path.exists(path):
        # Wipe previous results so stale files never mix with new output.
        shutil.rmtree(path)
    # makedirs (unlike mkdir) also creates any missing parent directories.
    os.makedirs(path)
def id2name(coco):
    """Return a dict mapping each category id to its category name.

    Args:
        coco: Object whose ``dataset['categories']`` is an iterable of
            dicts carrying ``'id'`` and ``'name'`` keys.

    Returns:
        ``{category_id: category_name}`` for every listed category.
    """
    return {
        category['id']: category['name']
        for category in coco.dataset['categories']
    }
def write_xml(anno_path, head, objs, tail):
    """Write one annotation XML file.

    The file consists of ``head``, one ``objstr`` element per entry of
    ``objs``, and finally ``tail``.

    Args:
        anno_path: Destination path of the XML file (overwritten if present).
        head: Already-formatted header text.
        objs: Iterable of ``[class_name, xmin, ymin, xmax, ymax]`` entries.
        tail: Closing text of the document.
    """
    # The context manager guarantees the handle is flushed and closed even
    # if a write fails. The XML carries no encoding declaration, for which
    # parsers assume UTF-8, so the encoding is pinned explicitly.
    with open(anno_path, "w", encoding="utf-8") as f:
        # Header
        f.write(head)
        # One <object> element per bounding box
        for obj in objs:
            f.write(objstr % (obj[0], obj[1], obj[2], obj[3], obj[4]))
        # Closing tag
        f.write(tail)
def save_annotations_and_imgs(coco, dataset, filename, objs):
    """Save the converted XML and copy its image into the output folders.

    Reads the module-level ``anno_dir``, ``img_dir`` and ``dataDir`` paths.

    Args:
        coco: Unused here; kept so existing callers keep working.
        dataset: Name of the dataset sub-folder under ``dataDir``.
        filename: Image file name, as listed in the COCO annotations.
        objs: ``[class_name, xmin, ymin, xmax, ymax]`` entries to write.

    Raises:
        IOError: If the source image cannot be read.
    """
    # splitext copes with any extension length (.jpg, .jpeg, .png, ...),
    # whereas slicing off three characters only works for three-letter ones.
    anno_path = anno_dir + os.path.splitext(filename)[0] + '.xml'
    img_path = dataDir + dataset + '/' + filename
    print(img_path)
    dst_imgpath = img_dir + filename
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here, before anything is written, with a message naming the file.
        raise IOError('Failed to read image: %s' % img_path)
    # Copy the image that belongs to this XML into the output folder.
    shutil.copy(img_path, dst_imgpath)
    # img.shape is (height, width, channels).
    head = headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
    write_xml(anno_path, head, objs, tailstr)
def showimg(coco, dataset, img, classes, cls_id, show=True):
    """Collect the boxes of the wanted classes for one image, optionally plotting them.

    Reads the module-level ``classes_names`` (class names to keep) and
    ``dataDir`` (dataset root, only needed when ``show`` is true).

    Args:
        coco: COCO API object used to look up the annotations.
        dataset: Name of the dataset sub-folder under ``dataDir``.
        img: Image record providing the ``'id'`` and ``'file_name'`` keys.
        classes: Mapping from category id to category name.
        cls_id: Category id(s) used to filter the annotations.
        show: When true, draw the boxes on the image and display it.

    Returns:
        List of ``[class_name, xmin, ymin, xmax, ymax]`` entries.
    """
    # Look up the annotations of this image for the requested categories.
    annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
    anns = coco.loadAnns(annIds)

    objs = []
    for ann in anns:
        class_name = classes[ann['category_id']]
        if class_name not in classes_names:
            continue
        print(class_name)
        if 'bbox' not in ann:
            continue
        bbox = ann['bbox']
        # bbox is treated as [x, y, width, height]; convert to corner form.
        xmin = int(bbox[0])
        ymin = int(bbox[1])
        xmax = int(bbox[2] + bbox[0])
        ymax = int(bbox[3] + bbox[1])
        objs.append([class_name, xmin, ymin, xmax, ymax])

    if show:
        # Open the image only when it is actually displayed, and close it
        # afterwards: during batch extraction (show=False) this avoids
        # decoding every image and leaking one file handle per image.
        with Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name'])) as I:
            draw = ImageDraw.Draw(I)
            for _, xmin, ymin, xmax, ymax in objs:
                draw.rectangle([xmin, ymin, xmax, ymax])
            plt.figure()
            plt.axis('off')
            plt.imshow(I)
            plt.show()
    return objs
if __name__ == '__main__':
    # NOTE: the names assigned here are module-level globals that the helper
    # functions above read directly (img_dir, anno_dir, dataDir, classes_names).

    # Output root for the extracted classes; the image and XML annotation
    # folders are created beneath it.
    savepath = "./coco_person/"
    img_dir = savepath + 'images/'
    anno_dir = savepath + 'XMLAnnotations/'
    # Dataset splits to extract from; adjust as needed,
    # e.g. ['train2014', 'train2017'].
    datasets_list = ['train2017']
    print(datasets_list)
    # Names of the classes to extract (COCO has 80); 'person' as an example.
    classes_names = ['person']
    # Root of the original dataset, holding the image folders and the
    # 'annotations' folder.
    dataDir = './COCO/'
    # (Re)create the output folders.
    mkr(img_dir)
    mkr(anno_dir)
    for dataset in datasets_list:
        # Path of the instances JSON file for this split.
        annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)
        # Initialise the COCO API with the annotation data.
        coco = COCO(annFile)
        # Show all classes known to this annotation file.
        classes = id2name(coco)
        print(classes)
        # Category ids of every requested class.
        classes_ids = coco.getCatIds(catNms=classes_names)
        print(classes_ids)
        # Each image is exported with the boxes of ALL requested classes, so
        # an image containing several of them needs converting only once.
        handled_ids = set()
        for cls in classes_names:
            # Category id of the class being extracted.
            cls_id = coco.getCatIds(catNms=[cls])
            img_ids = coco.getImgIds(catIds=cls_id)
            print(cls, len(img_ids))
            for imgId in tqdm(img_ids):
                if imgId in handled_ids:
                    # Already written while processing an earlier class.
                    continue
                handled_ids.add(imgId)
                img = coco.loadImgs(imgId)[0]
                filename = img['file_name']
                print(filename)
                objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
                print(objs)
                save_annotations_and_imgs(coco, dataset, filename, objs)