VOC數據集製作與分析

1. LabelImg標註文件製作VOC數據集

1.1 選取標註文件夾下的圖片與xml文件,並修改name(這一步其實有點多餘)

# -*- coding: UTF-8 -*-
import os 
import shutil

# return xml_list
def find_xml(xml_fold):
    """Collect the names of every ``.xml`` file below ``xml_fold``.

    The folder is walked recursively. Only the bare file names are returned
    (no directory part), in the order ``os.walk`` reports them.
    """
    return [
        name
        for _root, _subdirs, names in os.walk(xml_fold)
        for name in names
        if os.path.splitext(name)[1] == ".xml"
    ]

# return im_list
def find_image(im_fold):
    """Collect the names of every image file below ``im_fold``.

    A file counts as an image when its extension is one of png/jpg/bmp/jpeg,
    written either fully lower-case or fully upper-case. The folder is walked
    recursively and only bare file names (no directory part) are returned.
    """
    image_suffixes = (".png", ".PNG", ".jpg", ".JPG", ".bmp", ".BMP", ".jpeg", ".JPEG")
    found = []
    for _root, _subdirs, names in os.walk(im_fold):
        found.extend(
            name for name in names
            if os.path.splitext(name)[1] in image_suffixes
        )
    return found

def copy_xml_im(fold_path,xml_list,dst_fold):
    """Copy every xml file, and every image that has an annotation, into ``dst_fold``.

    Args:
        fold_path: folder that is searched recursively for xml and image files.
        xml_list: xml file names (with extension); an image is copied only
            when its base name matches one of these.
        dst_fold: existing destination folder. All files are copied flat into
            it, so files that share a name in different sub-folders overwrite
            each other.

    ``dst_fold`` may live inside ``fold_path``: it is excluded from the walk so
    that freshly copied files are not copied onto themselves (which makes
    ``shutil.copy`` raise a "same file" error).
    """
    # A set gives O(1) membership tests for the per-image lookup below.
    annotated_stems = set(os.path.splitext(name)[0] for name in xml_list)
    im_ext_list = (".png", ".PNG", ".jpg", ".JPG", ".bmp", ".BMP", ".jpeg", ".JPEG")
    dst_abs = os.path.abspath(dst_fold)

    for root, dirs, files in os.walk(fold_path):
        # Prune the destination in place so os.walk never descends into it.
        dirs[:] = [d for d in dirs
                   if os.path.abspath(os.path.join(root, d)) != dst_abs]
        if os.path.abspath(root) == dst_abs:
            # Source and destination are the same folder: nothing to copy.
            continue
        for f in files:
            stem, ext = os.path.splitext(f)
            if ext in im_ext_list:
                # Images are copied only when an annotation with that name exists.
                if stem in annotated_stems:
                    shutil.copy(os.path.join(root, f), os.path.join(dst_fold, f))
            elif ext == ".xml":
                shutil.copy(os.path.join(root, f), os.path.join(dst_fold, f))
            
def rename_xml_im(src_fold,dst_fold):
    """Copy xml/image pairs to ``dst_fold`` under zero-padded serial names.

    Every xml found below ``src_fold`` gets the next serial number
    (``00001``, ``00002``, ...) and is copied as ``<serial>.xml``; the image
    sharing its base name is copied as ``<serial><image extension>``.

    Args:
        src_fold: folder searched recursively for xml files.
        dst_fold: destination folder; created when it does not exist.

    The serial counter runs across all walked folders, so pairs from
    different sub-folders never overwrite each other. Folders and files are
    visited in sorted order to make the numbering reproducible.

    The image is looked up with every supported extension, ``.png`` first.
    When none exists the ``.png`` path is still handed to ``shutil.copy`` so
    the missing image surfaces as that call's error.
    """
    if not os.path.exists(dst_fold):
        os.makedirs(dst_fold)

    im_ext_list = (".png", ".PNG", ".jpg", ".JPG", ".bmp", ".BMP", ".jpeg", ".JPEG")
    new_name = 0
    for root, dirs, files in os.walk(src_fold):
        dirs.sort()  # in-place sort fixes the order in which os.walk descends
        for f in sorted(files):
            stem, ext = os.path.splitext(f)
            if ext != ".xml":
                continue
            new_name += 1
            serial = str(new_name).zfill(5)

            im_ext = ".png"
            for candidate in im_ext_list:
                if os.path.exists(os.path.join(root, stem + candidate)):
                    im_ext = candidate
                    break

            shutil.copy(os.path.join(root, f),
                        os.path.join(dst_fold, serial + ".xml"))
            shutil.copy(os.path.join(root, stem + im_ext),
                        os.path.join(dst_fold, serial + im_ext))
                
def fix_xml_content(src_fold):
    """Rewrite each xml below ``src_fold`` so it matches its own file name.

    For every ``.xml`` file, in place:
      * the line containing ``<filename>`` is replaced by
        ``    <filename><xml base name>.png</filename>``;
      * every line containing ``<path>`` is dropped.

    The edit is line based, so it relies on one tag per line (the layout
    LabelImg writes). Files are opened from the folder in which ``os.walk``
    found them, so annotations in sub-folders are handled as well.
    """
    for root, _dirs, files in os.walk(src_fold):
        for f1 in files:
            stem, ext = os.path.splitext(f1)
            if ext != '.xml':
                continue
            xml_file = os.path.join(root, f1)
            # Read everything first: the file is rewritten in place below.
            with open(xml_file, 'r') as f_r:
                lines = f_r.readlines()
            with open(xml_file, 'w') as f_w:
                for line in lines:
                    if '<filename>' in line:
                        line = '    <filename>' + stem + '.png</filename>\n'
                    if '<path>' in line:
                        continue
                    f_w.write(line)

if __name__=="__main__":
    # Root folder that holds the LabelImg images together with their xml files.
    fold_path = "/圖片與xml路徑"

    # Step 1: gather the xml files and the annotated images in <root>/temp.
    dst_fold = fold_path + "/temp"
    if not os.path.exists(dst_fold):
        os.makedirs(dst_fold)
    xml_list = find_xml(fold_path)
    im_list = find_image(fold_path)
    print("the number of xmlfile:",len(xml_list))
    print("the number of image:",len(im_list))
    copy_xml_im(fold_path,xml_list,dst_fold)

    # Step 2: copy the pairs to <root>/rename under serial names.
    rename_fold = fold_path+"/rename"
    rename_xml_im(dst_fold,rename_fold)

    # Step 3: patch <filename>/<path> inside the xml files.
    # NOTE(review): this targets <root>/VOC2007 although step 2 wrote to
    # <root>/rename -- presumably the folder is renamed/moved by hand in
    # between; confirm before relying on this step.
    voc_fold = fold_path+"/VOC2007"
    fix_xml_content(voc_fold)
    

1.2 將xml文件名隨機寫入train.txt/val.txt/test.txt

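'''
Split the xml files into consecutive groups of 10 and, inside every group,
draw random positions from [0, 9]:
  - 6 positions become train set indices
  - 2 positions become val set indices
  - the remaining 2 positions become test set indices

number of groups = (number of xml files) // 10
Files left over after the last full group are put into the test set.
'''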
import os
import random

def rand_index_list(start=0,end=9,proportion=1):
    """Return a sorted list of distinct random integers from ``[start, end]``.

    Args:
        start: smallest value that may be drawn (inclusive).
        end: largest value that may be drawn (inclusive). Integral floats
            such as ``6.0`` are accepted and truncated to int.
        proportion: the number of values drawn is ``proportion * 10``,
            rounded up to a whole number.

    Returns:
        The drawn values in ascending order; empty when the count is zero
        or negative.

    Raises:
        ValueError: when more values are requested than ``[start, end]``
            contains (a retry-until-unique loop would never finish here).
    """
    import math  # local import: only this function needs it

    start, end = int(start), int(end)
    # round() removes binary-float noise (e.g. (0.2 + 0.4) * 10 evaluates to
    # 6.000000000000001) before rounding up.
    count = max(int(math.ceil(round(proportion * 10, 9))), 0)
    return sorted(random.sample(range(start, end + 1), count))

def rand_serial(train_val_proportion,train_proportion):
    """Draw the train+val indices of one group of 10 and pick train among them.

    Args:
        train_val_proportion: share of the 10 slots used for train+val
            (0.8 -> 8 slots).
        train_proportion: share of the 10 slots used for train (0.6 -> 6
            slots); must not exceed ``train_val_proportion``.

    Returns:
        ``(train_val_indices, train_indices)``: both sorted lists of values
        in ``[0, 9]``; ``train_indices`` is a subset of ``train_val_indices``.
    """
    rand_train_val_index_list = rand_index_list(proportion=train_val_proportion)
    # Pick train positions over the WHOLE train+val list. Using
    # train_proportion*10 as the upper bound would hand a float to the random
    # helper and leave the last train+val slot permanently out of train.
    last_position = len(rand_train_val_index_list) - 1
    train_positions = rand_index_list(start=0, end=last_position,
                                      proportion=train_proportion)
    rand_train_index_list = [rand_train_val_index_list[p] for p in train_positions]
    return rand_train_val_index_list,rand_train_index_list

def find_rand_file(path_fold=None):
    """Randomly split the xml base names below ``path_fold`` into train/val/test.

    The names are processed in consecutive groups of 10. Inside each group
    ``rand_serial(0.8, 0.6)`` picks 8 train+val slots, 6 of which are train;
    the remaining 2 slots are test. Names beyond the last full group of 10
    also end up in the test list.

    Args:
        path_fold: folder searched recursively for ``.xml`` files.

    Returns:
        ``(train_list, val_list, test_list)`` of base names without extension.
    """
    xml_list = []
    for _root, _dirs, files in os.walk(path_fold):
        for f in files:
            stem, ext = os.path.splitext(f)
            if ext == ".xml":
                xml_list.append(stem)

    # Floor division: range() needs an int, and "/" yields a float on Python 3.
    index_list_num = len(xml_list) // 10

    rand_train_val_index_list, rand_train_index_list = [], []
    for index in range(index_list_num):
        offset = 10 * index  # shift the per-group indices to absolute positions
        train_val_index_list, train_index_list = rand_serial(0.8, 0.6)
        rand_train_val_index_list.extend(x + offset for x in train_val_index_list)
        rand_train_index_list.extend(x + offset for x in train_index_list)

    # Sets keep the "not in" filters linear instead of quadratic.
    train_index_set = set(rand_train_index_list)
    rand_val_index_list = [x for x in rand_train_val_index_list
                           if x not in train_index_set]

    train_list = [xml_list[i] for i in rand_train_index_list]
    val_list = [xml_list[i] for i in rand_val_index_list]
    assigned = set(train_list) | set(val_list)
    test_list = [x for x in xml_list if x not in assigned]
    return train_list,val_list,test_list

def write_file(voc_path):
    """Write the random split to ``train.txt``, ``val.txt`` and ``test.txt``.

    The split comes from ``find_rand_file(voc_path + "/Annotations")``. Each
    list is written one base name per line into
    ``voc_path + "/ImageSets/Main"``; that folder is created when it is
    missing, and existing txt files are overwritten.

    Args:
        voc_path: root folder of the VOC-style dataset.
    """
    train_list,val_list,test_list=find_rand_file(voc_path+"/Annotations")

    main_dir = voc_path+"/ImageSets/Main"
    if not os.path.isdir(main_dir):
        os.makedirs(main_dir)

    splits = (("train", train_list), ("val", val_list), ("test", test_list))
    for split_name, names in splits:
        with open(main_dir+"/"+split_name+".txt",'w') as f_w:
            f_w.writelines(x + "\n" for x in names)

if __name__ == "__main__":
    # Dataset root, given relative to the current working directory.
    voc_root = "VOC2007"
    write_file(voc_root)

2. VOC數據集各類別面積大小分佈

2.1 說明

計算VOC數據集中各個類別(以我們自己的數據集爲例:'car','cottage','town house','apartment','person','bird nest','honeycomb')的數量,以及各個類別中面積小於16、32、64、128、256、512、1024……像素的對象數量,並且繪製對應的直方圖。

import os
from PIL import Image
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import numpy as np
def calculateObjectArea(xml, objectname, im=None, outputPath=None):
    """Return the sorted bounding-box areas of one class in a VOC xml file.

    Args:
        xml: path of the annotation file.
        objectname: class name; an ``<object>`` is counted when its ``<name>``
            text (surrounding whitespace ignored) equals this value exactly.
        im: unused, kept for backward compatibility.
        outputPath: unused, kept for backward compatibility.

    Returns:
        Ascending list of ``(xmax - xmin) * (ymax - ymin)`` values, one per
        matching object. Coordinates written as floats are truncated to int.
        Objects whose ``<bndbox>`` is missing, incomplete or non-numeric are
        skipped.
    """
    areatuple = []
    root = ET.parse(xml).getroot()
    for obj in root.findall('object'):
        name = obj.findtext('name')
        # Compare the <name> tag only; testing every child's text would crash
        # on elements without text and match unrelated tags or longer names.
        if name is None or name.strip() != objectname:
            continue
        box = obj.find('bndbox')
        if box is None:
            continue
        try:
            xmin, ymin, xmax, ymax = [
                int(float(box.findtext(tag)))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
        except (TypeError, ValueError):
            # TypeError: tag missing (findtext gave None); ValueError: not a number.
            continue
        areatuple.append((xmax - xmin) * (ymax - ymin))
    areatuple.sort()
    return areatuple

def drawHist(tup, objectname, width):
    """Draw a green histogram of ``tup`` on the current pyplot axes.

    Args:
        tup: sequence of object areas; it is not modified.
        objectname: text used as the plot title.
        width: number of histogram bins (passed to ``plt.hist`` as ``bins``).

    Nothing is shown or saved here; the caller decides when to call
    ``plt.show()`` or ``plt.savefig()``.
    """
    # The old ``normed=False`` keyword is gone from current Matplotlib and
    # raw counts are the default anyway, so no normalisation flag is passed.
    plt.hist(tup, width, facecolor='g')

    plt.xlabel('Object Area')
    plt.ylabel('Number')
    plt.title(objectname)
    plt.grid(True, linewidth=1)

def searchxml(xmlpath, objectname):
    """Print and plot the area statistics of one class over a folder of xml files.

    Every ``.xml`` below ``xmlpath`` (paths containing ``.idea`` are ignored)
    is parsed with ``calculateObjectArea``. The total object count is printed
    and plotted with 200 bins; then, for each area limit in the table below,
    the objects with ``0 < area < limit`` are counted and -- when there is at
    least one -- printed and plotted with that limit's bin count.

    Args:
        xmlpath: folder searched recursively for annotation files.
        objectname: class name to analyse.
    """
    # (exclusive upper area limit, number of histogram bins)
    area_buckets = (
        (16, 30), (32, 30), (64, 30), (128, 30), (256, 30), (384, 30),
        (512, 30), (640, 30), (768, 50), (1024, 50), (1152, 60),
        (1280, 80), (1536, 80), (1792, 100), (2000, 100), (4000, 100),
        (6000, 100),
    )

    sumtup = []
    for root, _dirs, files in os.walk(xmlpath):
        for f1 in files:
            full_path = os.path.join(root, f1)
            if os.path.splitext(f1)[1] != '.xml' or '.idea' in full_path:
                continue
            # Accumulate inside the xml branch only: adding after the check
            # would reuse the previous file's areas for every non-xml file.
            sumtup.extend(calculateObjectArea(full_path, objectname))

    print(objectname+' total:', len(sumtup))
    drawHist(sumtup, objectname +'\'s number:'+str(len(sumtup)), 200)

    for limit, bins in area_buckets:
        below = [x for x in sumtup if 0 < x < limit]
        if not below:
            continue
        count = len(below)
        print('the number of '+objectname+' (\'area<'+str(limit)+'\' pixel):', count)
        drawHist(below,
                 objectname+'--the number of \'area<'+str(limit)+'\' pixel:'+str(count),
                 bins)

if __name__=='__main__':
    # Annotation folder (relative to the working directory) and the classes
    # to analyse, one searchxml() run per class.
    annotation_dir = './Annotations'
    class_names = ['car','cottage','town house','apartment','person','bird nest','honeycomb']
    for class_name in class_names:
        searchxml(annotation_dir, class_name)

重新上一個稍微好一點的代碼:

import os 
import xml.etree.ElementTree as ET

def object_name_with_path_list(path_fold):
    """Return the paths of all ``.xml`` files found recursively in ``path_fold``.

    Each entry is the walked directory joined with the file name, so the
    paths are relative or absolute exactly as ``path_fold`` is.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(path_fold)
        for name in names
        if os.path.splitext(name)[1] == ".xml"
    ]

def get_object_area(path_fold):
    """Print how many annotated boxes fall below a few area thresholds.

    All xml files returned by ``object_name_with_path_list(path_fold)`` are
    parsed. For every ``<bndbox>`` directly inside an ``<object>`` that is a
    direct child of the root, the area ``(xmax - xmin) * (ymax - ymin)`` is
    collected (all classes together). The summary line reports the total
    number of boxes and the counts with area < 64, < 128, < 256 and < 512.
    Nothing is returned.
    """
    area_list = []
    for xml_file in object_name_with_path_list(path_fold):
        annotation = ET.parse(xml_file).getroot()
        for node in annotation:
            if node.tag != "object":
                continue
            for part in node:
                if part.tag != "bndbox":
                    continue
                # Text of every coordinate tag, keyed by tag name.
                coords = {}
                for key in ("xmin", "xmax", "ymin", "ymax"):
                    coords[key] = [c.text for c in part if c.tag == key]
                box_width = int(coords["xmax"][0]) - int(coords["xmin"][0])
                box_height = int(coords["ymax"][0]) - int(coords["ymin"][0])
                area_list.append(box_width * box_height)

    below_counts = [
        len([area for area in area_list if area < limit])
        for limit in (64, 128, 256, 512)
    ]
    print("total numbers:{},lower_64:{},lower_128:{},lower_256:{},lower_512:{}"
          .format(len(area_list), *below_counts))

if __name__=="__main__":
    # Annotation folder, relative to the current working directory.
    annotation_dir = "VOC2007/Annotations/"
    get_object_area(annotation_dir)

2.2 結果

直接上圖

3.裁剪labelImg標註的所有對象

# -*- coding:utf-8 -*-
'''
程序放在VOC2007文件夾下
'''
import os 
import xml.etree.ElementTree as ET
from PIL import Image
import time

def mkdirs(outputPath):
    """Create ``outputPath`` and any missing parents; no-op when the path exists."""
    if os.path.exists(outputPath):
        return
    os.makedirs(outputPath)

# crop all object of VOC2007 datasets
def crop_object(xml_path,im_path):
    """Crop every annotated object out of its image and save it per class.

    For each ``.xml`` found recursively below ``xml_path`` the image
    ``<im_path>/<xml base name>.jpg`` is opened and every ``<object>`` box is
    cropped from it. Crops are written to
    ``objects/<class name>/<xml base name>_<object index>.jpg`` relative to
    the current working directory.

    The output name is built from the annotation name and the object's
    position in the file so that every crop gets its own file; a name made of
    the current second plus ``xmin`` lets crops overwrite each other.

    Args:
        xml_path: folder with the VOC xml annotations.
        im_path: folder with the matching ``.jpg`` images.
    """
    for folder, _dirs, files in os.walk(xml_path):
        for f in files:
            item, ext = os.path.splitext(f)
            if ext != ".xml":
                continue
            # Parse from the folder the file was found in, not from xml_path.
            root = ET.parse(os.path.join(folder, f)).getroot()

            # (class name, (xmin, ymin, xmax, ymax)) per annotated object
            objects = []
            for child in root:
                if child.tag != "object":
                    continue
                box = child.find('bndbox')
                position = tuple(int(box.find(tag).text)
                                 for tag in ("xmin", "ymin", "xmax", "ymax"))
                objects.append((child.find('name').text, position))
            if not objects:
                continue  # nothing to crop, so the image is not opened

            im = Image.open(os.path.join(im_path, item+".jpg"))
            try:
                for index, (object_name, position) in enumerate(objects):
                    out_dir = os.path.join("objects", object_name)
                    mkdirs(out_dir)
                    im.crop(position).save(
                        os.path.join(out_dir, "%s_%d.jpg" % (item, index)))
            finally:
                # Release the file handle even when a crop/save fails.
                im.close()

if __name__ == '__main__':
    # Both folders are given relative to the current working directory.
    annotation_dir = "Annotations/"
    image_dir = "JPEGImages/"
    crop_object(annotation_dir, image_dir)

4.選擇某一類對象object,刪除其餘所有對象

# -*- coding:utf-8 -*-
import os 
import xml.etree.ElementTree as ET
import time

def mkdirs(outputPath):
    """Create ``outputPath`` and any missing parents; no-op when the path exists."""
    if os.path.exists(outputPath):
        return
    os.makedirs(outputPath)

# choose specific class(choose_class),remove all other classes.
def choose_object(xml_path,im_path,dst_path,choose_class):
    """Write copies of the annotations that keep only ``choose_class`` objects.

    Every ``.xml`` found recursively below ``xml_path`` is parsed and all
    ``<object>`` elements whose ``<name>`` differs from ``choose_class``
    (or that have no ``<name>``) are removed. If at least one object of the
    chosen class remains, the tree is written to ``dst_path`` under the
    original file name; otherwise no file is written for that annotation.

    Args:
        xml_path: folder with the source annotations.
        im_path: unused; kept so existing callers keep working.
        dst_path: output folder, created as soon as one xml file is found.
        choose_class: class name to keep.
    """
    dst_ready = False
    for folder, _dirs, files in os.walk(xml_path):
        for f in files:
            if os.path.splitext(f)[-1] != ".xml":
                continue
            if not dst_ready:
                # Create the output folder once instead of once per file.
                mkdirs(dst_path)
                dst_ready = True

            # Parse from the folder the file was found in, not from xml_path.
            tree = ET.parse(os.path.join(folder, f))
            root = tree.getroot()

            kept = 0
            # Iterate over a copy: elements are removed from root in the loop.
            for child in list(root):
                if child.tag != "object":
                    continue
                if child.findtext('name') == choose_class:
                    kept += 1
                else:
                    root.remove(child)

            if kept > 0:
                tree.write(os.path.join(dst_path, f))

if __name__ == '__main__':
    # Folders are relative to the current working directory; only objects of
    # class "car" are kept in the output annotations.
    annotation_dir = "Annotations/"
    image_dir = "JPEGImages/"
    output_dir = "dst_xml_path/"
    choose_object(annotation_dir, image_dir, output_dir, "car")

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章