[Object Detection in Practice] Part 1: Downloading, Configuring, and Using the TensorFlow Object Detection API

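To begin with, a brief introduction: the TensorFlow Object Detection API is an open-source framework built on top of TensorFlow that makes it easy to build, train, and deploy object detection models.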

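Setting up the basic deep learning environment on Windows 10 (Anaconda, the CPU or GPU build of TensorFlow, PyCharm, and so on) is not covered here; plenty of tutorials are available online.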

The additional Python libraries required are pillow and lxml, which can be installed with pip install.
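
Before moving on, it can help to confirm that the packages are importable. The sketch below uses only the standard library (importlib.util.find_spec). The module list is my own assumption: PIL and lxml are the import names of pillow and lxml, and matplotlib, six, and numpy are added because the prediction script later in this post imports them.

```python
import importlib.util

# Import names to check (assumed list). PIL is the import name of the pillow package;
# matplotlib, six and numpy are included because the prediction script imports them.
REQUIRED_MODULES = ["PIL", "lxml", "matplotlib", "six", "numpy"]


def find_missing_modules(module_names):
    """Return the module names that cannot be found in the current environment."""
    missing = []
    for name in module_names:
        # find_spec returns None when a top-level module is not installed
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = find_missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```

Any module it reports can then be installed with pip install, using the package name (pillow for PIL).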

1. Download the TensorFlow Object Detection API

    Download the source code directly from GitHub: https://github.com/tensorflow/models
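
If you prefer to script the download instead of using the browser, here is a minimal sketch using only the standard library. It assumes GitHub's branch-archive URL pattern and that the repository's default branch is master, so the archive unpacks into a models-master folder (the directory name used in step 3). The helper names download_file and extract_zip are hypothetical.

```python
import os
import urllib.request
import zipfile

# Assumed archive URL for the master branch of tensorflow/models
MODELS_ZIP_URL = "https://github.com/tensorflow/models/archive/refs/heads/master.zip"


def download_file(url, dest_path):
    """Download url to dest_path and return the local path."""
    urllib.request.urlretrieve(url, dest_path)
    return dest_path


def extract_zip(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the sorted top-level entry names."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
        # Collect the first path component of every member, e.g. "models-master"
        return sorted({name.split("/")[0] for name in zf.namelist()})


if __name__ == "__main__":
    zip_path = download_file(MODELS_ZIP_URL, "models-master.zip")
    print(extract_zip(zip_path, os.getcwd()))
```

The full repository archive is large, so the download may take a while.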

2. Download Protoc

Protoc is used to compile the .proto files in the object_detection/protos directory of the downloaded models source code into Python (.py) files.

On Windows, version 3.4 is recommended (download link).

After downloading and extracting it, add the bin folder inside the protoc directory to the PATH environment variable.

Open a command prompt (cmd) and run protoc. If it prints its usage information, the installation succeeded.
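
The same check can be done from Python with the standard library: shutil.which looks up an executable on PATH the way the shell does, and protoc --version prints the installed version (for example libprotoc 3.4.0). This is only a convenience sketch; find_executable and protoc_version are hypothetical helper names.

```python
import shutil
import subprocess


def find_executable(name):
    """Return the full path of an executable on PATH, or None if it is not found."""
    return shutil.which(name)


def protoc_version(protoc_name="protoc"):
    """Return the output of `protoc --version`, or None if protoc is not on PATH."""
    path = find_executable(protoc_name)
    if path is None:
        return None
    result = subprocess.run([path, "--version"], stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, universal_newlines=True)
    return result.stdout.strip()


if __name__ == "__main__":
    version = protoc_version()
    print(version if version else "protoc was not found on PATH")
```

If it reports that protoc was not found, reopen the command prompt after editing PATH, since running shells keep the old value.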

3. Compile the files in object_detection\protos

Extract the TensorFlow Object Detection API source downloaded earlier, cd into the models-master\research directory in the command prompt, and run:

protoc ./object_detection/protos/*.proto --python_out=. 

This compiles the .proto files in the object_detection/protos directory into .py files.
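
On Windows, a commonly reported problem with protoc releases newer than 3.4 is that the *.proto wildcard in the command above is not expanded, which is a common reason for recommending 3.4. If you are on a newer protoc, one workaround is to call it once per file. A minimal sketch, assuming protoc is on PATH and the script is run from models-master\research; build_protoc_commands and compile_protos are hypothetical helper names.

```python
import glob
import os
import subprocess


def build_protoc_commands(research_dir, protoc="protoc"):
    """Build one protoc command per .proto file under object_detection/protos.

    Paths are made relative to research_dir (with forward slashes) so the
    generated modules land next to the sources, like `--python_out=.` above.
    """
    pattern = os.path.join(research_dir, "object_detection", "protos", "*.proto")
    commands = []
    for proto_path in sorted(glob.glob(pattern)):
        rel_path = os.path.relpath(proto_path, research_dir).replace(os.sep, "/")
        commands.append([protoc, rel_path, "--python_out=."])
    return commands


def compile_protos(research_dir, protoc="protoc"):
    """Run protoc for every .proto file; raises CalledProcessError on failure."""
    for cmd in build_protoc_commands(research_dir, protoc):
        subprocess.run(cmd, cwd=research_dir, check=True)


if __name__ == "__main__":
    compile_protos(os.getcwd())
```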

When the command finishes, open the object_detection/protos directory; you should see the generated .py files (named *_pb2.py).
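
protoc names each generated module after its source file with a _pb2.py suffix (for example, string_int_label_map.proto becomes string_int_label_map_pb2.py), so this check can be automated by listing .proto files that have no matching module. A minimal sketch; missing_pb2_modules is a hypothetical helper name.

```python
import glob
import os


def missing_pb2_modules(protos_dir):
    """Return the .proto file names in protos_dir that have no generated *_pb2.py."""
    missing = []
    for proto_path in sorted(glob.glob(os.path.join(protos_dir, "*.proto"))):
        stem = os.path.splitext(os.path.basename(proto_path))[0]
        if not os.path.isfile(os.path.join(protos_dir, stem + "_pb2.py")):
            missing.append(os.path.basename(proto_path))
    return missing


if __name__ == "__main__":
    # Assumes the script is run from models-master/research
    missing = missing_pb2_modules(os.path.join("object_detection", "protos"))
    print("All protos compiled." if not missing else "Not compiled: %s" % missing)
```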

4. Run object detection with a pre-trained model

First, create a new project in PyCharm (I named mine using_pre-trained_model_to_detect_objects), then copy the models-master\research\object_detection folder from the downloaded TensorFlow Object Detection API source into this new project.
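
The copy can also be scripted with shutil.copytree. A minimal sketch, assuming you pass in the research directory and the project directory; copy_object_detection and the example paths are hypothetical. copytree refuses to write into an existing destination, so the sketch only replaces an old copy when overwrite=True.

```python
import os
import shutil


def copy_object_detection(research_dir, project_dir, overwrite=False):
    """Copy research_dir/object_detection into project_dir and return the new path."""
    src = os.path.join(research_dir, "object_detection")
    dst = os.path.join(project_dir, "object_detection")
    if os.path.exists(dst):
        if not overwrite:
            raise FileExistsError(dst)
        # Remove the previous copy so copytree can recreate it
        shutil.rmtree(dst)
    shutil.copytree(src, dst)
    return dst


if __name__ == "__main__":
    # Hypothetical example paths; adjust to where you extracted the repository
    copy_object_detection(r"C:\models-master\research",
                          r"C:\projects\using_pre-trained_model_to_detect_objects")
```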

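Create an object_detection_tutorial.py file in the project to perform detection. Place it in the project root, next to the copied object_detection folder, because the script imports object_detection.utils and refers to paths such as object_detection/data.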

The prediction script is shown below. Note that its paths are relative to the working directory:

import numpy as np
import os
import six.moves.urllib as urllib
import tarfile
import tensorflow as tf

from distutils.version import StrictVersion
import matplotlib.pyplot as plt
from PIL import Image

from object_detection.utils import ops as utils_ops

if StrictVersion(tf.__version__) < StrictVersion('1.12.0'):
  raise ImportError('Please upgrade your TensorFlow installation to v1.12.*.')

from object_detection.utils import label_map_util

from object_detection.utils import visualization_utils as vis_util

MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to the frozen .pb model.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

# Label map file for the COCO dataset
PATH_TO_LABELS = os.path.join('object_detection/data', 'mscoco_label_map.pbtxt')

PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

# Download the model archive and extract the frozen graph
def downloadModel():
  urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
  with tarfile.open(MODEL_FILE) as tar_file:
    for file in tar_file.getmembers():
      file_name = os.path.basename(file.name)
      # Only extract frozen_inference_graph.pb (it keeps its MODEL_NAME/ folder)
      if 'frozen_inference_graph.pb' in file_name:
        tar_file.extract(file, os.getcwd())


# Load the frozen graph into a new tf.Graph
def loadModel():
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return detection_graph

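# Convert a PIL image into a uint8 array of shape (height, width, 3)
def load_image_into_numpy_array(image):
  # Convert to RGB first so grayscale or RGBA images also reshape to 3 channels
  image = image.convert('RGB')
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)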

# Run object detection on a single image
def run_inference_for_single_image(image, graph):
  with graph.as_default():
    with tf.Session() as sess:
      # Get handles to input and output tensors
      ops = tf.get_default_graph().get_operations()
      all_tensor_names = {output.name for op in ops for output in op.outputs}
      tensor_dict = {}
      for key in [
          'num_detections', 'detection_boxes', 'detection_scores',
          'detection_classes'
      ]:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
          tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
              tensor_name)
      image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

      # Run inference
      output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: image})

      # all outputs are float32 numpy arrays, so convert types as appropriate
      output_dict['num_detections'] = int(output_dict['num_detections'][0])
      output_dict['detection_classes'] = output_dict[
          'detection_classes'][0].astype(np.int64)
      output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
      output_dict['detection_scores'] = output_dict['detection_scores'][0]
  return output_dict

def predict(detection_graph):
    # Build the dict mapping class IDs to category names once, before the loop
    category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
    for image_path in TEST_IMAGE_PATHS:
        image = Image.open(image_path)
        # the array based representation of the image will be used later in order to prepare the
        # result image with boxes and labels on it.
        image_np = load_image_into_numpy_array(image)
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Actual detection.
        output_dict = run_inference_for_single_image(image_np_expanded, detection_graph)
        # Visualization of the results of a detection.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            output_dict['detection_boxes'],
            output_dict['detection_classes'],
            output_dict['detection_scores'],
            category_index,
            instance_masks=output_dict.get('detection_masks'),
            use_normalized_coordinates=True,
            line_thickness=8)
        plt.figure(figsize=(12, 8))
        plt.imshow(image_np)
        plt.axis('off')
        plt.show()


if __name__ == '__main__':
    # Uncomment on the first run to download and extract the model archive
    # downloadModel()
    detection_graph = loadModel()
    predict(detection_graph)

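Running the script displays each test image (image1.jpg and image2.jpg) with the detected objects' bounding boxes, class names, and confidence scores drawn on it. In my run, the objects in the test images were detected successfully.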
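
The output_dict returned by run_inference_for_single_image can also be used without drawing anything, for example to print class names with pixel coordinates. In the Object Detection API, detection_boxes holds normalized [ymin, xmin, ymax, xmax] values, and category_index maps each class ID to a dict with 'id' and 'name'. The helper below, summarize_detections, is a hypothetical sketch; its 0.5 default threshold mirrors the default min_score_thresh of visualize_boxes_and_labels_on_image_array, which you may want to adjust.

```python
def summarize_detections(output_dict, category_index, image_width, image_height,
                         min_score_thresh=0.5):
    """Return (class_name, score, (left, top, right, bottom)) tuples above the threshold.

    Boxes in output_dict['detection_boxes'] are normalized [ymin, xmin, ymax, xmax];
    they are scaled to pixel coordinates using the image size.
    """
    num = output_dict['num_detections']
    results = []
    for box, cls, score in zip(output_dict['detection_boxes'][:num],
                               output_dict['detection_classes'][:num],
                               output_dict['detection_scores'][:num]):
        if score < min_score_thresh:
            continue
        ymin, xmin, ymax, xmax = box
        # Fall back to the numeric ID if the class is missing from the label map
        name = category_index.get(int(cls), {}).get('name', str(int(cls)))
        results.append((name, float(score),
                        (int(xmin * image_width), int(ymin * image_height),
                         int(xmax * image_width), int(ymax * image_height))))
    return results
```

Inside the predict loop this could be called as summarize_detections(output_dict, category_index, image.size[0], image.size[1]), since PIL's Image.size is (width, height).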

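You are welcome to follow my WeChat official account, AI計算機視覺工坊 (AI Computer Vision Workshop), which posts articles on machine learning, deep learning, and computer vision from time to time. I hope we can learn and exchange ideas together.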