Tensorflow YOLO代碼解析(2)

下面介紹數據集處理部分代碼。訓練數據的處理主要包括兩部分,一是在模型訓練前使用數據預處理腳本preprocess_pascal_voc.py 對下載得到的數據集處理得到一個訓練樣本信息列表文件。二是在模型訓練時根據訓練樣本的信息列表文件將數據讀入到隊列中,供模型訓練時讀取batch數據使用。

其他相關的部分請見:
YOLO代碼解析(1) 代碼總覽與使用
YOLO代碼解析(2) 數據處理
YOLO代碼解析(3) 模型和損失函數
YOLO代碼解析(4) 訓練和測試代碼

1.preprocess_pascal_voc.py :數據預處理

pascal_voc數據集的標註數據保存在xml中,每張圖片對應一個單獨的xml文件,文件內容如:

<annotation>
	<folder>VOC2007</folder>
	<filename>000001.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
		<flickrid>341012865</flickrid>
	</source>
	<owner>
		<flickrid>Fried Camels</flickrid>
		<name>Jinky the Fruit Bat</name>
	</owner>
	<size>
		<width>353</width>
		<height>500</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>dog</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>48</xmin>
			<ymin>240</ymin>
			<xmax>195</xmax>
			<ymax>371</ymax>
		</bndbox>
	</object>
	<object>
		<name>person</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>8</xmin>
			<ymin>12</ymin>
			<xmax>352</xmax>
			<ymax>498</ymax>
		</bndbox>
	</object>
</annotation>

本腳本功能是解析xml文件,對每一張圖片,得到一條形如[image_path xmin1 ymin1 xmax1 ymax1 class_id1 xmin2 ymin2 xmax2 ymax2 class_id2] (如:/home/jerry/tensorflow-yolo/data/VOCdevkit/VOC2007/JPEGImages/009960.jpg 26 140 318 318 13 92 46 312 267 14)的記錄,並寫入到文件中。

xml解析代碼:

def parse_xml(xml_file):
  """
  解析 xml_文件
  輸入:xml文件路徑
  返回:圖像路徑和對應的label信息
  """
  # 使用ElementTree解析xml文件
  tree = ET.parse(xml_file)
  root = tree.getroot()
  image_path = ''
  labels = []

  for item in root:
    if item.tag == 'filename':
      image_path = os.path.join(DATA_PATH, 'VOC2007/JPEGImages', item.text)
    elif item.tag == 'object':
      obj_name = item[0].text
      # 將objetc的名稱轉換爲ID
      obj_num = classes_num[obj_name]
      # 依次得到Bbox的左上和右下點的座標
      xmin = int(item[4][0].text)
      ymin = int(item[4][1].text)
      xmax = int(item[4][2].text)
      ymax = int(item[4][3].text)
      labels.append([xmin, ymin, xmax, ymax, obj_num])

  # 返回圖像的路徑和label信息(Bbox座標和類別ID)
  return image_path, labels

def convert_to_string(image_path, labels):
  """
     將圖像的路徑和lable信息轉爲string
  """
  out_string = ''
  out_string += image_path
  for label in labels:
    for i in label:
      out_string += ' ' + str(i)
  out_string += '\n'
  return out_string

def main():
  out_file = open(OUTPUT_PATH, 'w')

  # 獲取所有的xml標註文件的路徑
  xml_dir = DATA_PATH + '/VOC2007/Annotations/'
  xml_list = os.listdir(xml_dir)
  xml_list = [xml_dir + temp for temp in xml_list]

  # 解析xml文件,得到圖片名稱和lables,並轉換得到圖片的路徑
  for xml in xml_list:
    try:
      image_path, labels = parse_xml(xml)
      # 將解析得到的結果轉爲string並寫入文件
      record = convert_to_string(image_path, labels)
      out_file.write(record)
    except Exception:
      pass

  out_file.close()

2. text_dataset.py:準備訓練用batch數據

主要將在訓練過程中將訓練數據讀入到隊列中,起到緩存的作用。

class TextDataSet(DataSet):
  """TextDataSet
     對數據預處理中得到的data list文件進行處理
     text file format:
     image_path xmin1 ymin1 xmax1 ymax1 class1 xmin2 ymin2 xmax2 ymax2 class2
  """

  def __init__(self, common_params, dataset_params):
    """
    Args:
      common_params: A dict
      dataset_params: A dict
    """
    #process params
    self.data_path = str(dataset_params['path'])
    self.width = int(common_params['image_size'])
    self.height = int(common_params['image_size'])
    self.batch_size = int(common_params['batch_size'])
    self.thread_num = int(dataset_params['thread_num'])
    self.max_objects = int(common_params['max_objects_per_image'])

    #定義兩個隊列,一個存放訓練樣本的list,另個存放訓練樣本的數據(image & label)
    self.record_queue = Queue(maxsize=10000)
    self.image_label_queue = Queue(maxsize=512)

    self.record_list = []  

    # 讀取經過數據預處理得到的 pascal_voc.txt
    input_file = open(self.data_path, 'r')

    for line in input_file:
      line = line.strip()
      ss = line.split(' ')
      ss[1:] = [float(num) for num in ss[1:]]  # 將座標和類別ID轉爲float
      self.record_list.append(ss)

    self.record_point = 0
    self.record_number = len(self.record_list)

    # 計算每個epoch的batch數目
    self.num_batch_per_epoch = int(self.record_number / self.batch_size)

    # 啓動record_processor進程
    t_record_producer = Thread(target=self.record_producer)
    t_record_producer.daemon = True 
    t_record_producer.start()

    # 啓動record_customer進程
    for i in range(self.thread_num):
      t = Thread(target=self.record_customer)
      t.daemon = True
      t.start() 

  def record_producer(self):
    """record_queue 的processor
    """
    while True:
      if self.record_point % self.record_number == 0:
        random.shuffle(self.record_list)
        self.record_point = 0
      # 從record_list讀取一條訓練樣本信息到record_queue
      self.record_queue.put(self.record_list[self.record_point])
      self.record_point += 1

  def record_process(self, record):
    """record 處理過程
    Args: record 
    Returns:
      image: 3-D ndarray
      labels: 2-D list [self.max_objects, 5] (xcenter, ycenter, w, h, class_num)
      object_num:  total object number  int 
    """
    image = cv2.imread(record[0])  # record[0]是image 的路徑

    # 對圖像做色彩空間變換和尺寸縮放
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h = image.shape[0]
    w = image.shape[1]

    width_rate = self.width * 1.0 / w 
    height_rate = self.height * 1.0 / h

    # 尺寸調整到 (448,448)
    image = cv2.resize(image, (self.height, self.width))

    labels = [[0, 0, 0, 0, 0]] * self.max_objects

    i = 1
    object_num = 0

    while i < len(record):
      xmin = record[i]
      ymin = record[i + 1]
      xmax = record[i + 2]
      ymax = record[i + 3]
      class_num = record[i + 4]
     
      # 由於圖片縮放過,對label座標做同樣處理
      xcenter = (xmin + xmax) * 1.0 / 2 * width_rate
      ycenter = (ymin + ymax) * 1.0 / 2 * height_rate

      box_w = (xmax - xmin) * width_rate
      box_h = (ymax - ymin) * height_rate

      labels[object_num] = [xcenter, ycenter, box_w, box_h, class_num]
      object_num += 1
      i += 5
      if object_num >= self.max_objects:
        break
    return [image, labels, object_num]

  def record_customer(self):
    """record queue的使用者
       取record queue中數據,經過處理後,送到image_label_queue中
    """
    while True:
      item = self.record_queue.get()
      out = self.record_process(item)
      self.image_label_queue.put(out)

下一篇:YOLO代碼解析(3) 模型和損失函數

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章