用Tensorflow Object Detection API 訓練自己的數據集

一、準備數據集

Tensorflow Object Detection API 用 TFRecord 文件格式讀取數據,需把 VOC 格式的數據集進行轉換(我自己的數據集是VOC2007)

1、修改 tensorflow/models/object_detection/create_pascal_tf_record.py 文件第84行和162行。
這裏寫圖片描述

這裏寫圖片描述

2、修改tensorflow/models/object_detection/data/pascal_label_map.pbtxt 文件裏的類別.
3、運行命令:

# From tensorflow/models
python object_detection/create_pascal_tf_record.py \
    --label_map_path=object_detection/data/pascal_label_map.pbtxt \
    --data_dir=VOCdevkit --year=VOC2007 --set=train \
    --output_path=pascal_train.record
python object_detection/create_pascal_tf_record.py \
    --label_map_path=object_detection/data/pascal_label_map.pbtxt \
    --data_dir=VOCdevkit --year=VOC2007 --set=val \
    --output_path=pascal_val.record

執行後會在object_detection文件夾下生成pascal_train.record和pascal_val.record兩個文件。

二、下載預訓練模型

下載地址:https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md
解壓命令例子:

tar -xzvf ssd_mobilenet_v1_coco.tar.gz

三、修改配置文件

修改 object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_pets.config文件:

(1)num_classes:修改爲自己的classes num

這裏寫圖片描述

(2)將所有PATH_TO_BE_CONFIGURED的地方修改爲自己之前設置的路徑(共5處)

這裏寫圖片描述

這裏寫圖片描述

四、訓練

進入object_detection目錄,運行:

tensorflow/models$ python object_detection/train.py --train_dir='/home/anngic/tensorflow/train' --pipeline_config_path='/home/anngic/tensorflow/models/object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config'

五、tensorboad

輸入命令:

tensorboard --logdir=/home/shz/TF-OD-Test/train

在瀏覽器中輸入https://0.0.0.0:6006,就能看到訓練曲線了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章