Tensorflow2.0 YOLO篇之提取xml文件信息
數據集介紹
數據集下載地址:
鏈接:https://pan.baidu.com/s/1ZP9H2ym3Vp4Sda1mNiv9Pw
提取碼:5okb
複製這段內容後打開百度網盤手機App,操作更方便哦
這次選擇的數據集是甜菜(sugarbeet)和雜草(weed)的數據集
在數據集的xml文件中包含了圖片中每個物體的邊界框座標(xmin,ymin,xmax,ymax)和label
其中的一個xml文件
<annotation>
<folder>train</folder>
<filename>X2-10-1.png</filename>
<path /><source>
<database>Unknown</database>
</source>
<size>
<width>512</width>
<height>512</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>weed</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>71</xmin>
<ymin>265</ymin>
<xmax>115</xmax>
<ymax>278</ymax>
</bndbox>
</object>
......
<object>
<name>sugarbeet</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>322</xmin>
<ymin>266</ymin>
<xmax>363</xmax>
<ymax>294</ymax>
</bndbox>
</object>
</annotation>
現在我們要做的工作就是將這些數據儲存到numpy數組中去,代碼中我儘可能地寫了註釋。書寫這段代碼的時候選擇了vscode作爲編輯工具,因爲vscode對於jupyter的支持較好,可以在編寫的過程中更加方便地查看每一步的運行結果
同時在編寫這一步的時候需要注意的一個點就是每個圖片中的物體個數可能不一樣,這樣每張圖片的boxes個數就不一致,無法直接堆疊成一個數組。因爲每個圖片中的框信息都沒有超過5個(上圖除外那是我自己畫的),所以我們每一張圖片都設計有五個空位,不足的就用0來填充(代碼中實際是按所有圖片中最大的框數max_boxes來分配空位)
#%%
import os
import glob
# Set the native log level BEFORE importing TensorFlow so that it also
# applies to the messages emitted while the library is being loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
# Fix both random seeds so repeated runs are reproducible.
tf.random.set_seed(2233)
np.random.seed(2233)
# %%
# Report the TensorFlow version and whether a GPU is usable.
print(tf.__version__)
# The GPU check lives in the `tf.test` module (`tf.text` does not exist in
# core TensorFlow and would raise AttributeError).
print(tf.test.is_gpu_available())
# %%
import xml.etree.ElementTree as ET
def parse_annotation(img_dir,ann_dir,labels):
    """Parse Pascal-VOC style XML annotations into image paths and boxes.

    Args:
        img_dir: directory of the images; joined with each ``<filename>``.
        ann_dir: directory containing the annotation ``.xml`` files.
        labels: sequence of class names, e.g. ``('sugarbeet', 'weed')``.
            A box's label id is its index in this sequence plus 1, so that
            0 stays free for the zero-padded rows.

    Returns:
        imgs: list of image paths, one per annotation file, in sorted
            annotation-file-name order.
        boxes: float array of shape ``[len(imgs), max_boxes, 5]`` whose rows
            are ``x1-y1-x2-y2-label``; ``max_boxes`` is the largest number of
            boxes found in a single file and shorter images are padded with
            all-zero rows.

    Raises:
        AssertionError: if an annotated width or height is not 512.
        ValueError: if an object name is not present in ``labels``.
    """
    imgs_info = []
    max_boxes = 0
    # Only parse XML files (directory listings may contain stray entries such
    # as notebook checkpoints) and sort them so the output order does not
    # depend on the file system.
    ann_files = sorted(
        name for name in os.listdir(ann_dir) if name.lower().endswith('.xml')
    )
    # for each annotation xml file
    for ann in ann_files:
        tree = ET.parse(os.path.join(ann_dir, ann))
        img_info = dict()
        img_info['object'] = []
        boxes_counter = 0
        for elem in tree.iter():
            if 'filename' in elem.tag:
                img_info['filename'] = os.path.join(img_dir, elem.text)
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                assert img_info['width'] == 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                # Check the height itself (this used to re-check the width).
                assert img_info['height'] == 512
            if 'object' in elem.tag or 'part' in elem.tag:
                # x1-y1-x2-y2-label
                object_info = [0, 0, 0, 0, 0]
                boxes_counter += 1
                for attr in list(elem):
                    if 'name' in attr.tag:
                        # +1 keeps label 0 reserved for padding rows
                        label = labels.index(attr.text) + 1
                        object_info[4] = label
                    if 'bndbox' in attr.tag:
                        for pos in list(attr):
                            if 'xmin' in pos.tag:
                                object_info[0] = int(pos.text)
                            if 'ymin' in pos.tag:
                                object_info[1] = int(pos.text)
                            if 'xmax' in pos.tag:
                                object_info[2] = int(pos.text)
                            if 'ymax' in pos.tag:
                                object_info[3] = int(pos.text)
                img_info['object'].append(object_info)
        imgs_info.append(img_info)  # filename, w/h, box info
        # track the largest number of boxes seen in one image
        if boxes_counter > max_boxes:
            max_boxes = boxes_counter
    # [b, max_boxes, 5], 5 is x1-y1-x2-y2-label
    boxes = np.zeros([len(imgs_info), max_boxes, 5])
    imgs = []  # image path list
    for i, img_info in enumerate(imgs_info):
        # An image without objects keeps its all-zero rows; an empty array
        # of shape (0,) cannot be broadcast into the (0, 5) slice.
        if img_info['object']:
            # [N, 5], N: number of boxes in this image
            img_boxes = np.array(img_info['object'])
            # overwrite the first N rows with the real boxes
            boxes[i, :img_boxes.shape[0]] = img_boxes
        imgs.append(img_info['filename'])
        print(img_info['filename'], boxes[i, :5])
    # imgs: list of image paths
    # boxes: [b, max_boxes, 5]
    return imgs, boxes
# %%
# Class names in label order: sugarbeet -> 1, weed -> 2 (0 is the padding).
obj_names = ('sugarbeet','weed')
# Load every training annotation into image paths and a padded box array.
imgs,boxes = parse_annotation(
    img_dir='data/train/image',
    ann_dir='data/train/annotation',
    labels=obj_names,
)