tensorflow+openCV進行目標檢測

一、準備

數據集:coco

模型:目標檢測常用的三個模型有:SSD、Faster R-CNN、YOLO

免去訓練的過程,模型成品下載:github地址

環境:TensorFlow 1.14.0、openCV 4.1.1

二、檢測

1、羅列類別名稱

person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant

stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe

backpack
umbrella


handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle

wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
potted plant
bed

dining table


toilet

tv monitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator

book
clock
vase
scissors
teddy bear
hair drier
toothbrush

 2、下載模型

比如下載   faster_rcnn_inception_v2_coco_2018_01_28  解壓到當前目錄下

準備好待檢測圖片

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf

# 加載coco數據集模型
model_path = "faster_rcnn_inception_v2_coco_2018_01_28"
frozen_pb_file = os.path.join(model_path, 'frozen_inference_graph.pb')

# 加載coco數據集分類
f = open("coco/classes.txt", "r")
class_names = f.readlines()

# model_path = ""
# frozen_pb_file = os.path.join(model_path, 'model.pb')


score_threshold = 0.3

img_file = 'pic/class.jpg'

# Read the graph.
with tf.gfile.FastGFile(frozen_pb_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())


with tf.Session() as sess:
    # Restore session
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    # for op in sess.graph.get_operations():
    #     print(op)

    # Read and preprocess an image.
    img_cv2 = cv2.imread(img_file)
    img_height, img_width, _ = img_cv2.shape

    img_in = cv2.resize(img_cv2, (300, 300))
    img_in = img_in[:, :, [2, 1, 0]]  # BGR2RGB

    # Run the model
    outputs = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                    sess.graph.get_tensor_by_name('detection_scores:0'),
                    sess.graph.get_tensor_by_name('detection_boxes:0'),
                    sess.graph.get_tensor_by_name('detection_classes:0')],
                   feed_dict={
                       'image_tensor:0': img_in.reshape(1,
                                                        img_in.shape[0],
                                                        img_in.shape[1],
                                                        3)})

    # Visualize detected bounding boxes.
    num_detections = int(outputs[0][0])
    for i in range(num_detections):
        classId = int(outputs[3][0][i])
        score = float(outputs[1][0][i])
        bbox = [float(v) for v in outputs[2][0][i]]
        if score > score_threshold:
            x = bbox[1] * img_width
            y = bbox[0] * img_height
            right = bbox[3] * img_width
            bottom = bbox[2] * img_height
            # 標框
            cv2.rectangle(img_cv2,
                          (int(x), int(y)),
                          (int(right), int(bottom)),
                          (125, 255, 51),
                          thickness=3)
            # 文字"class_name, score"
            cv2.putText(img_cv2,
                        class_names[classId - 1][:-1] + "," + str("%.2f" % score),
                        (int(x), int(y)),
                        cv2.FONT_HERSHEY_DUPLEX, 3, (0, 0, 255), 3)
            print(str(classId) + ",class:" + class_names[classId - 1][:-1] + ",score:%.2f" % score)

plt.figure(figsize=(10, 8))
plt.imshow(img_cv2[:, :, ::-1])
plt.title("TensorFlow MobileNetV2-SSD")
plt.axis("off")
plt.show()

三、後期展望

使用新的數據訓練自己的model,進行SSD或者Faster R-CNN模型的遷移學習,運用到更具體的場景中去。

四、參考文獻

【1】TensorFlow 目標檢測模型轉換爲 OpenCV DNN 可調用格式

【2】SSD模型的原理

【3】coco2017數據集80類別名稱與id號的對應關係

發佈了106 篇原創文章 · 獲贊 42 · 訪問量 9萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章