Win10: A Step-by-Step Guide to the TensorFlow Object Detection API (Training SSD+MobileNetv2 on VOC2007)

Project repository: https://github.com/tensorflow/models

This post covers the basics of the TensorFlow Object Detection API. It uses SSD+MobileNetv2 as the example for both inference and training, and explains each step in plain language.

Step1: Environment setup

The required environment is described in installation.md, but that guide does not cover Windows 10. On Win10 I set everything up with Anaconda:

# Create a virtual environment
conda create -n ssd python=3.6

# Activate the environment
conda activate ssd

# Install the dependencies
conda install tensorflow-gpu==1.15.0
conda install Cython
conda install contextlib2
conda install pillow
conda install lxml
conda install jupyter
conda install matplotlib

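After installing the packages above, one last step remains: install pycocotools. Without it, the training program will not start.

For details on installing pycocotools under Win10, see the article "Installing pycocotools on Win10".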
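Before moving on, a quick way to confirm that everything from this step is importable is a small check script. This is a sketch: the module list is my own mapping of the conda packages above to their import names (for example, pillow is imported as PIL), and jupyter is left out because the API itself does not need it.

```python
import importlib.util

# Import names of the packages installed above (assumed mapping: pillow -> PIL).
REQUIRED_MODULES = ["tensorflow", "Cython", "contextlib2", "PIL",
                    "lxml", "matplotlib", "pycocotools"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```

Run it with the ssd environment activated; anything listed as missing still needs to be installed.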

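Step2: Generate code with protoc

After downloading the project, cd into a path like D:\...\tensorflow\models\research (adjust to your setup).

Also get the protoc tool ready.

protoc download: link: https://pan.baidu.com/s/1FJsrFVYBtG-cT6mnuKznOw  extraction code: wrtb

Locate your protoc executable and run: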

D:/.../bin/protoc object_detection/protos/*.proto --python_out=.
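If protoc reports that it cannot find `object_detection/protos/*.proto`, note that cmd.exe does not expand wildcards itself, so whether the command above works depends on the protoc build. One workaround is to call protoc once per file. A minimal sketch, where the `PROTOC` path is a hypothetical placeholder:

```python
import glob
import os
import subprocess

# Hypothetical location of protoc.exe; replace it with where you unpacked protoc.
PROTOC = r"D:\protoc\bin\protoc.exe"


def protoc_commands(protoc, proto_dir="object_detection/protos"):
    """Build one protoc command per .proto file, so no wildcard expansion is needed."""
    protos = sorted(glob.glob(os.path.join(proto_dir, "*.proto")))
    # Use forward slashes so the paths look the same as in the command above.
    return [[protoc, path.replace(os.sep, "/"), "--python_out=."] for path in protos]


def compile_protos(protoc, proto_dir="object_detection/protos"):
    """Run every command; must be called from the models/research directory."""
    for cmd in protoc_commands(protoc, proto_dir):
        subprocess.run(cmd, check=True)
```

Usage from models/research: `compile_protos(PROTOC)`. Because of `--python_out=.`, each generated `*_pb2.py` file is written next to its `.proto` file.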

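Step3: Set up the paths

In the Anaconda folder, go to envs/ssd/Lib/site-packages/ (adjust to your setup).

Create a new .txt file and rename it to tensorflow_model.pth

Put the paths of the following folders from the TensorFlow models repository into it, one per line (adjust to your setup):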

D:\python\tensorflow\models\research
D:\python\tensorflow\models\research\slim
D:\python\tensorflow\models\research\object_detection
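The same .pth file can also be written from Python, which avoids hunting for the site-packages folder by hand. This sketch assumes it is run with the ssd environment's python and reuses the example paths above; `write_pth` is a hypothetical helper name.

```python
import os
import sysconfig

# Example paths from the text above; change them to where you cloned tensorflow/models.
MODEL_PATHS = [
    r"D:\python\tensorflow\models\research",
    r"D:\python\tensorflow\models\research\slim",
    r"D:\python\tensorflow\models\research\object_detection",
]


def write_pth(paths, site_packages=None, name="tensorflow_model.pth"):
    """Write one directory per line into a .pth file and return the file's path."""
    if site_packages is None:
        # "purelib" is the site-packages folder of the interpreter running this code,
        # so run it with the ssd environment activated.
        site_packages = sysconfig.get_paths()["purelib"]
    pth_file = os.path.join(site_packages, name)
    with open(pth_file, "w", encoding="utf-8") as f:
        f.write("\n".join(paths) + "\n")
    return pth_file
```

Call `write_pth(MODEL_PATHS)` once, then start a new Python session; the three folders should now appear in `sys.path`.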

Step4: Test whether the environment is configured correctly

From the models/research directory, run:

python object_detection/builders/model_builder_test.py

Step5: Run detection on a single image

import os
import sys
import tarfile

import cv2 as cv
import numpy as np
import tensorflow as tf
# "utils" resolves because object_detection was added to the path in Step3
from utils import label_map_util
from utils import visualization_utils as vis_util


MODEL_NAME = 'ssd_mobilenet_v2_coco_2018_03_29'   # pre-trained model; the .tar.gz is read directly, no manual extraction needed
MODEL_FILE = 'D:/python/tensorflow/' + MODEL_NAME + '.tar.gz'

PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

PATH_TO_LABELS = os.path.join('D:/python/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90    # COCO uses 90 class ids

# Extract only frozen_inference_graph.pb; it ends up in ./<MODEL_NAME>/
with tarfile.open(MODEL_FILE) as tar_file:
    for file in tar_file.getmembers():
        file_name = os.path.basename(file.name)
        if 'frozen_inference_graph.pb' in file_name:
            tar_file.extract(file, os.getcwd())

# Load the frozen graph into its own tf.Graph
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Map class ids to human-readable names
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

def load_image_into_numpy(image):
    # Helper for PIL images; not needed below because cv.imread already returns a numpy array
    (im_w, im_h) = image.size
    return np.array(image.getdata()).reshape(im_h, im_w, 3).astype(np.uint8)

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image = cv.imread("D:/python/tensorflow/models/research/object_detection/test_images/image3.jpg")
        if image is None:
            sys.exit("Could not read the test image; check the path")
        print(image.shape)
        # OpenCV loads images as BGR, but the model expects RGB input
        image_rgb = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        image_np_expanded = np.expand_dims(image_rgb, axis=0)  # add the batch dimension
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        (boxes, scores, classes, num_detections) = sess.run([boxes, scores, classes, num_detections],
                                                            feed_dict={image_tensor: image_np_expanded})
        # Draws boxes and labels onto image_rgb in place
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_rgb,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=4
        )
        cv.namedWindow("enhanced", cv.WINDOW_NORMAL)
        cv.resizeWindow("enhanced", 640, 750)
        cv.imshow("enhanced", cv.cvtColor(image_rgb, cv.COLOR_RGB2BGR))  # back to BGR for display
        cv.waitKey(0)
        cv.destroyAllWindows()

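That is the complete code for detecting objects in one image. It reads the image with OpenCV; if OpenCV is not installed, install it with pip (pip install opencv-python), not with conda install.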
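The raw outputs can also be used without the visualizer. detection_boxes holds normalized [ymin, xmin, ymax, xmax] coordinates, so a small helper can filter by score and convert boxes to pixel corners. This is a sketch; the function name and the 0.5 threshold are my own choices.

```python
def boxes_to_pixels(boxes, scores, classes, image_height, image_width, min_score=0.5):
    """Keep detections scoring at least min_score and convert boxes to pixel corners.

    boxes: rows of normalized [ymin, xmin, ymax, xmax] (detection_boxes after np.squeeze).
    scores, classes: parallel sequences. Returns (class_id, score, (x1, y1, x2, y2)) tuples.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        corners = (int(round(xmin * image_width)), int(round(ymin * image_height)),
                   int(round(xmax * image_width)), int(round(ymax * image_height)))
        results.append((int(cls), float(score), corners))
    return results
```

Inside the session block, after `sess.run`, it could be called as `boxes_to_pixels(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), image.shape[0], image.shape[1])`; the class ids can then be looked up in the category index.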

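Step6: Generate TF-Records for VOC2007

Place the VOC2007 dataset in a suitable location. Here I use the training set as the example for generating a TF-Record (run from models/research):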

python object_detection/dataset_tools/create_pascal_tf_record.py ^
--label_map_path=D:/tensorflow/dataset/pascal_label_map.pbtxt ^
--data_dir=D:/tensorflow/dataset/VOCdevkit ^
--year=VOC2007 ^
--set=train ^
--output_path=D:/tensorflow/dataset/tfrecord/pascal_train.record

(^ is the line-continuation character in cmd / Anaconda Prompt; you can also put the whole command on one line.)
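To check that the record file was actually written without loading TensorFlow, you can walk the TFRecord framing directly. Each record is a little-endian uint64 length, a 4-byte masked CRC of that length, the payload, and a 4-byte masked CRC of the payload. The sketch below skips CRC verification, so it is only a sanity check; `count_tfrecords` is a hypothetical helper.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in a TFRecord file by reading only the length headers (CRCs not checked)."""
    size = os.path.getsize(path)
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(12)  # uint64 length + uint32 masked CRC of the length
            if not header:
                break  # clean end of file
            if len(header) < 12:
                raise ValueError("truncated record header in %s" % path)
            (length,) = struct.unpack("<Q", header[:8])
            f.seek(length + 4, os.SEEK_CUR)  # skip the payload and its uint32 masked CRC
            if f.tell() > size:
                raise ValueError("truncated record payload in %s" % path)
            count += 1
    return count
```

For the command above, `count_tfrecords("D:/tensorflow/dataset/tfrecord/pascal_train.record")` should equal the number of images in the chosen split.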

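Step7: Edit the configuration file

Config file path: \...\research\object_detection\samples\configs\ssd_mobilenet_v2_coco.config

Change 1: the number of classes: set num_classes from COCO's 90 to VOC's 20.

As the comments in the file suggest, press Ctrl+F, search for "PATH_TO_BE_CONFIGURED", and replace every occurrence with the corresponding path.

Change 2: the pre-trained weights path, fine_tune_checkpoint (extract the model archive and point it at model.ckpt, the prefix shared by the model.ckpt.* files).

Change 3: point train_input_reader and eval_input_reader at your dataset. The train record is the --output_path from Step6; for eval_input_reader, generate a separate record with --set=val.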

train_input_reader: {
  tf_record_input_reader {
    input_path: "D:/python/tensorflow/dataset/output/pascal_train.record"
  }
  label_map_path: "D:/python/tensorflow/dataset/output/pascal_label_map.pbtxt"
}
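Editing the file by hand works fine; if you regenerate configs often, the three changes can also be scripted. This is a sketch that assumes the layout of the sample configs (one num_classes, one fine_tune_checkpoint, train_input_reader before eval_input_reader); the function names and arguments are my own. Paths are written with forward slashes because backslashes act as escape characters inside strings in the config's text format.

```python
import re


def _set_quoted(text, field, value):
    """Replace every `field: "..."` in text with the new value."""
    value = value.replace("\\", "/")  # forward slashes avoid text-format escapes
    # A function replacement keeps the value literal (no re escape processing).
    return re.sub(r'\b%s:\s*".*?"' % field, lambda m: '%s: "%s"' % (field, value), text)


def patch_pipeline_config(text, num_classes, checkpoint, train_record, eval_record, label_map):
    """Apply changes 1-3 above to the text of a sample pipeline config."""
    text = re.sub(r"\bnum_classes:\s*\d+", lambda m: "num_classes: %d" % num_classes, text)
    text = _set_quoted(text, "fine_tune_checkpoint", checkpoint)
    text = _set_quoted(text, "label_map_path", label_map)
    # Assume everything before eval_input_reader is training config, the rest evaluation.
    split = text.find("eval_input_reader")
    if split < 0:
        split = len(text)
    head = _set_quoted(text[:split], "input_path", train_record)
    tail = _set_quoted(text[split:], "input_path", eval_record)
    return head + tail
```

Read the sample config as a string, pass it through `patch_pipeline_config(text, 20, ...)` with your own paths, and write the result to the file you will pass as --pipeline_config_path.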

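Step8: Start training

python object_detection/model_main.py ^
--pipeline_config_path=D:/python/tensorflow/dataset/output/ssd_mobilenet_v2_coco.config ^
--model_dir=D:/python/tensorflow/dataset/output/model ^
--num_train_steps=10 ^
--num_eval_steps=5 ^
--alsologtostderr

Run this from models/research (^ is the cmd line-continuation character). Ten training steps only verify that the pipeline runs; use a much larger --num_train_steps for real training.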
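Because the command spans several lines, it is easy to break in cmd.exe. An alternative is to build the argument list in Python and hand it to `subprocess.run`. The helper below is hypothetical and only assembles the same flags shown above; it does not add or validate any others.

```python
import sys


def model_main_args(pipeline_config, model_dir, num_train_steps, num_eval_steps=None):
    """Build the argument list for object_detection/model_main.py using the flags above."""
    args = [sys.executable, "object_detection/model_main.py",  # the current env's python
            "--pipeline_config_path=" + pipeline_config,
            "--model_dir=" + model_dir,
            "--num_train_steps=%d" % num_train_steps]
    if num_eval_steps is not None:
        args.append("--num_eval_steps=%d" % num_eval_steps)
    args.append("--alsologtostderr")
    return args
```

From models/research, with the ssd environment active: `subprocess.run(model_main_args("D:/python/tensorflow/dataset/output/ssd_mobilenet_v2_coco.config", "D:/python/tensorflow/dataset/output/model", 10, 5), check=True)`. Passing a list avoids shell quoting and line-continuation issues entirely.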

 
