Tensorflow2.0 YOLO篇之提取xml文件信息

Tensorflow2.0 YOLO篇之提取xml文件信息



數據集介紹

數據集下載地址:

鏈接:https://pan.baidu.com/s/1ZP9H2ym3Vp4Sda1mNiv9Pw 
提取碼:5okb 
複製這段內容後打開百度網盤手機App,操作更方便哦

這次選擇的數據集是甜菜(sugarbeet)和雜草(weed)的數據集
在這裏插入圖片描述
在數據集的xml文件中包含了圖片中物體的位置形狀(x,y,w,h)和label
其中的一個xml文件

<annotation>
	<folder>train</folder>
	<filename>X2-10-1.png</filename>
	<path /><source>
		<database>Unknown</database>
	</source>
	<size>
		<width>512</width>
		<height>512</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>weed</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>71</xmin>
			<ymin>265</ymin>
			<xmax>115</xmax>
			<ymax>278</ymax>
		</bndbox>
	</object>
	
	......

	<object>
		<name>sugarbeet</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>322</xmin>
			<ymin>266</ymin>
			<xmax>363</xmax>
			<ymax>294</ymax>
		</bndbox>
	</object>
</annotation>

現在我們要做的工作就是將這些數據儲存到numpy數組中去,代碼中我儘可能的寫了註釋,書寫這個的是否選擇了vscode作爲編譯工具,以爲vscode對於jupyter的支持較好,可以在編寫的過程中更加方便的查看每一步的運行結果

同時在編寫這一步的時候需要注意的一個點就是每個圖片中的物體個數可能不一樣,這樣我們的boxes的個數就有問題。因爲每個圖片中的框信息都沒有超過5個(上圖除外那是我自己畫的),所以我們每一張圖片都涉及有五個空,不足的就用0來填充

#%%
import tensorflow as tf
import os,glob
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
from tensorflow import keras
# set random seed
tf.random.set_seed(2233)
np.random.seed(2233)

# %%
print(tf.__version__)
print(tf.text.is_gpu_available())



# %%
import xml.etree.ElementTree as ET

def parse_annotation(img_dir,ann_dir,labels):
    # parse annotation and save is into numpy array
    # img_dir: image path
    # ann_dir: annotation xml file path
    # labels: ('sugarweet','weed')
    imgs_info =[]
    # for each annotation xml file
    max_boxes = 0
    for ann in os.listdir(ann_dir):
        tree = ET.parse(os.path.join(ann_dir,ann))

        img_info = dict()
        img_info['object'] = []
        boxes_counter = 0
        for elem in tree.iter():
            if 'filename' in elem.tag:
                img_info['filename'] = os.path.join(img_dir,elem.text)
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                assert img_info['width'] == 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                assert img_info['width'] == 512
            if 'object' in elem.tag or 'part' in elem.tag:
                # x1-y1-x2-y2-label
                object_info =  [0,0,0,0,0]
                boxes_counter += 1
                for attr in list(elem):
                    # add image info into object_info
                    if 'name' in attr.tag:
                        label = labels.index(attr.text) + 1
                        object_info[4] = label
                    if 'bndbox' in attr.tag:
                        for pos in list(attr):
                            if 'xmin' in pos.tag:
                                object_info[0] = int(pos.text)
                            if 'ymin' in pos.tag:
                                object_info[1] = int(pos.text)
                            if 'xmax' in pos.tag:
                                object_info[2] = int(pos.text)
                            if 'ymax' in pos.tag:
                                object_info[3] = int(pos.text)
                img_info['object'].append(object_info)
        imgs_info.append(img_info) # filename,w/h/box_info
        # (N,5) = (max_objects_num,5) 5 is x-y-w-h-label
        if boxes_counter > max_boxes:
            max_boxes = boxes_counter
    # the maximum boxes number is max_boxes
    # [b,max_things,5]
    boxes = np.zeros([len(imgs_info),max_boxes,5])
    imgs = [] # filename last
    for i,img_info in enumerate(imgs_info):
        # [N,5] N: boxes number
        img_boxes = np.array(img_info['object'])
        # overwrite the N boxes info
        boxes[i,:img_boxes.shape[0]] = img_boxes
        imgs.append(img_info['filename'])
        print(img_info['filename'],boxes[i,:5])
    # imgs: list of image path
    # boxes:[b,40,5]
    return imgs,boxes

# %%
obj_names = ('sugarbeet','weed')
imgs,boxes = parse_annotation('data/train/image','data/train/annotation',obj_names)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章