在VOC格式的SSD訓練數據標籤中添加圖片寬高等信息,並進行座標越界檢查。
# coding: utf-8
import os
from xml.etree.ElementTree import ElementTree,Element
from PIL import Image
def read_xml(in_path):
'''''讀取並解析xml文件
in_path: xml路徑
return: ElementTree'''
tree = ElementTree()
tree.parse(in_path)
return tree
def write_xml(tree, out_path):
'''''將xml文件寫出
tree: xml樹
out_path: 寫出路徑'''
tree.write(out_path, encoding="utf-8", xml_declaration=True)
def if_match(node, kv_map):
'''''判斷某個節點是否包含所有傳入參數屬性
node: 節點
kv_map: 屬性及屬性值組成的map'''
for key in kv_map:
if node.get(key) != kv_map.get(key):
return False
return True
# ----------------search -----------------
def find_nodes(tree, path):
'''''查找某個路徑匹配的所有節點
tree: xml樹
path: 節點路徑'''
return tree.findall(path)
def get_node_by_keyvalue(nodelist, kv_map):
'''''根據屬性及屬性值定位符合的節點,返回節點
nodelist: 節點列表
kv_map: 匹配屬性及屬性值map'''
result_nodes = []
for node in nodelist:
if if_match(node, kv_map):
result_nodes.append(node)
return result_nodes
# ---------------change ----------------------
def change_node_properties(nodelist, kv_map, is_delete=False):
'''修改/增加 /刪除 節點的屬性及屬性值
nodelist: 節點列表
kv_map:屬性及屬性值map'''
for node in nodelist:
for key in kv_map:
if is_delete:
if key in node.attrib:
del node.attrib[key]
else:
node.set(key, kv_map.get(key))
def change_node_text(nodelist, text, is_add=False, is_delete=False):
'''''改變/增加/刪除一個節點的文本
nodelist:節點列表
text : 更新後的文本'''
for node in nodelist:
if is_add:
node.text += text
elif is_delete:
node.text = ""
else:
node.text = text
def create_node(tag, property_map, content):
'''新造一個節點
tag:節點標籤
property_map:屬性及屬性值map
content: 節點閉合標籤裏的文本內容
return 新節點'''
element = Element(tag, property_map)
element.text = content
return element
def add_child_node(nodelist, element):
'''''給一個節點添加子節點
nodelist: 節點列表
element: 子節點'''
for node in nodelist:
node.append(element)
def del_node_by_tagkeyvalue(nodelist, tag, kv_map):
'''''同過屬性及屬性值定位一個節點,並刪除之
nodelist: 父節點列表
tag:子節點標籤
kv_map: 屬性及屬性值列表'''
for parent_node in nodelist:
children = parent_node.getchildren()
for child in children:
if child.tag == tag and if_match(child, kv_map):
parent_node.remove(child)
if __name__ == "__main__":
path = "E:\\UnderWaterDetection\\train\\box" # xml文件所在的目錄
files = os.listdir(path) # 得到文件夾下所有文件名稱
s = []
for xmlFile in files: # 遍歷文件夾
xmlPath = os.path.join(path, xmlFile)
#xmlPath = "E:\\UnderWaterDetection\\train\\box\\000001.xml"
xmlPartName = xmlFile.split(".")[0]
imageFile = xmlPartName + ".jpg"
imagePath = os.path.join("E:\\UnderWaterDetection\\train\\image\\", imageFile)
img = Image.open(imagePath)
################ 1. 讀取xml文件 ##########
tree = read_xml(xmlPath)
root = tree.getroot()
################ 2. 屬性修改 ###############
#nodes = find_nodes(tree, "object") # 找到父節點
for obj in root.iter('object'): # 獲取object節點中的name子節點
obj.find("bndbox")
xmin_node = bnd_node.find("xmin")
ymin_node = bnd_node.find("ymin")
xmax_node = bnd_node.find("xmax")
ymax_node = bnd_node.find("ymax")
if int(xmin_node.text) <= 0 :
xmin_node.text = "0"
if int(ymin_node.text) <= 0 :
ymin_node.text = "0"
if int(xmax_node.text) > img.size[0] :
xmax_node.text = str(img.size[0])
if int(ymax_node.text) > img.size[1] :
ymax_node.text = str(img.size[1])
#result_nodes = get_node_by_keyvalue(nodes, {"name": "BProcesser"}) # 通過屬性準確定位子節點
#change_node_properties(result_nodes, {"age": "1"}) # 修改節點屬性
#change_node_properties(result_nodes, {"value": ""}, True) # 刪除節點屬性
################# 3. 節點修改 ##############
a = create_node("size", {}, "") # 新建節點
w = create_node("width", {}, str(img.size[0])) # 新建節點
a.append(w)
h = create_node("height", {}, str(img.size[1])) # 新建節點
a.append(h)
d = create_node("depth", {}, "3") # 新建節點
a.append(d)
#add_child_node(a, w) # 插入到父節點之下
#add_child_node(a, h) # 插入到父節點之下
#add_child_node(a, d) # 插入到父節點之下
#add_child_node(root, a) # 插入到父節點之下
root.append(a)
################# 4. 刪除節點 ################
#del_parent_nodes = find_nodes(tree, "processers/services/service") # 定位父節點
#target_del_node = del_node_by_tagkeyvalue(del_parent_nodes, "chain", {"sequency": "chain1"}) # 準確定位子節點並刪除之
################# 5. 修改節點文本 ############
#text_nodes = get_node_by_keyvalue(find_nodes(tree, "processers/services/service/chain"), {"sequency": "chain3"}) # 定位節點
#change_node_text(text_nodes, "new text")
################ 6. 輸出到結果文件 ##########
savePath = os.path.join("E:\\UnderWaterDetection\\train\\box1\\", xmlFile)
write_xml(tree, savePath)
參考文獻: