Tensorflow object_detection API (Part 1)

The Tensorflow Object Detection API is an open-source framework built on TensorFlow that can be used to build, train, and deploy object detection services.

GitHub link: the tensorflow/models repository

object_detection lives under research in the TensorFlow models repository. When downloading object_detection, it is recommended to download the entire models repository, because some required packages are not inside object_detection itself but in sibling directories.
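A minimal way to get the whole repository (a sketch assuming git is available; this is the public tensorflow/models repository referenced above):

git clone https://github.com/tensorflow/models.git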

Installation

The object_detection API depends on the packages protobuf, pillow, lxml, jupyter, and matplotlib.
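As a hedged example, these can typically be installed with pip (assuming pip is available; protoc, the Protobuf compiler used below, is usually installed separately, e.g. through the system package manager):

pip install protobuf pillow lxml jupyter matplotlib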

These packages are quite likely to report errors during installation; the most common cause is that gcc is missing from the Linux system, or that the installed gcc version is too old or too new.

The Tensorflow Object Detection API uses Protobufs to configure models and training parameters. Before the framework can be used, the Protobuf libraries must be compiled. This is done by running the following command from the research/ directory of the downloaded and extracted models repository:

protoc object_detection/protos/*.proto --python_out=.
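If compilation succeeds, protoc generates a *_pb2.py file next to each .proto file; listing them is a quick sanity check (a plain directory listing, nothing framework-specific):

ls object_detection/protos/*_pb2.py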

When running locally, the models/ and slim directories should be appended to PYTHONPATH. After consulting a number of sources, there are roughly the following ways to do this:
1. Add a .pth file to Python's site-packages directory that lists the models and slim paths (see the sketch after the snippet below)
2. Append the paths in the Python code itself:

import sys
sys.path.append('path/to/models')  # path to the models (research) directory
sys.path.append('path/to/slim')    # path to the slim directory
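For method 1, a minimal sketch (the file name and paths are placeholders, and it assumes the repository was cloned to /home/user/models with slim living under research/ as described above): create a file such as tensorflow_models.pth inside the site-packages directory of the Python environment, listing one directory per line; Python appends every existing directory named in a .pth file to sys.path at interpreter startup.

/home/user/models/research
/home/user/models/research/slim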

That completes the installation.

Testing the installation

You can check whether the Tensorflow Object Detection API is installed correctly by running the following command:

python object_detection/builders/model_builder_test.py

Testing an MSCOCO model

MSCOCO is Microsoft's COCO dataset, which contains many object categories together with their annotations. The tutorial provides a download for an SSD MobileNet model (ssd_mobilenet is said to be the fastest of the provided models, but also the least accurate).

The test code is located in object_detection_tutorial.ipynb inside the object_detection folder (open the .ipynb with Jupyter Notebook). It contains a very detailed walkthrough. The detection result on the test images:

[image: detection results on the tutorial test images]

Video implementation

After installing python-opencv (easiest via apt-get), the current implementation performs single-threaded object detection. The complete code is below:

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

if tf.__version__ < '1.4.0':
  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')
# %matplotlib inline  # only needed when running inside a Jupyter notebook; keep it commented out in a plain .py script

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")

from utils import label_map_util

from utils import visualization_utils as vis_util

# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90
# If the model has already been downloaded, this block can be commented out.
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)
# The code below is adapted from the tutorial and differs from it in places.
import cv2
cap = cv2.VideoCapture(0)  # open camera 0 (the default webcam)
success = True
font = cv2.FONT_HERSHEY_SIMPLEX
with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Definite input and output Tensors for detection_graph
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represent how level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    def returnimage(image_np):
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Actual detection.
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection (boxes are drawn onto image_np in place).
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        return image_np
    while success:
      success, image = cap.read()
      if not success:  # stop if the camera stops delivering frames
        break
      image = returnimage(image)
      cv2.imshow("test", image)
      # Press 'q' to save the current frame and quit.
      if cv2.waitKey(1) & 0xFF == ord('q'):
        cv2.imwrite('test.jpg', image)
        break
  cap.release()
  cv2.destroyAllWindows()
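A usage note on the code above: it assumes it is saved as a .py file and run from inside the object_detection directory (which is why it does sys.path.append("..") and reads data/mscoco_label_map.pbtxt by a relative path). While it runs, every webcam frame is shown in the "test" window with detection boxes drawn on it; pressing q writes the current frame to test.jpg and exits.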

Screenshot of the running result:
[image: webcam detection screenshot]
