使用Retinanet訓練自己的數據集

目錄

 

目錄

1 構建Retinanet環境

2 生成CSV文件

3訓練

4.轉化模型

5.測試

6.評測

loss可視化

ap,precision-recall


數據集什麼的看我之前博客,資源裏也有標記好的數據集,這裏主要寫一下我配置使用訓練過程。

1 構建Retinanet環境

1.代碼庫下載地址https://github.com/fizyr/keras-retinanet,或git命令:

git clone https://github.com/fizyr/keras-retinanet.git

2.獲得代碼庫後進入keras-retinanet文件夾,確認有未安裝numpy
 

cd keras-retinanet

pip install numpy --user

在這個文件夾內運行下面代碼來安裝keras-retinanet庫,確認你已經根據自己的系統需求安裝了tensorflow

pip install . --user
python setup.py build_ext --inplace

2 生成CSV文件

訓練自己的數據集需要至少兩個CSV文件,一個文件包含標註數據,另一個則包含各個類別名及其對應的ID序號映射。

先拋出我的文件位置,新建一個csv文件夾,data文件裏放置的是訓練圖片及標籤

三個csv就是我們要生成的

參考博客https://blog.csdn.net/qq_27171347/article/details/88878346

"""
進入到csv文件夾下
運行方式:命令行 python xml2csv.py -i indir(圖片及標註的母目錄)
      注:必須參數: -i 指定包含有圖片及標註的母文件夾,圖片及標註可不在同一子目錄裏,但名稱必須一一對應
                     (圖片格式默認.jpg,若爲其他格式可見代碼中註釋自行修改)
          可選參數: -p 交叉驗證集拆分比,默認0.05
                   -t 生成訓練集CSV文件名稱,默認train.csv
                   -v 生成交叉驗證集CSV文件名稱,默認val.csv
                   -c 生成類別映射CSV文件名稱,默認class.csv
"""

import os
import xml.etree.ElementTree as ET
import random
import math
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--indir', type=str)
    parser.add_argument('-p', '--percent', type=float, default=0.1) #0.05
    parser.add_argument('-t', '--train', type=str, default='train.csv')
    parser.add_argument('-v', '--val', type=str, default='val.csv')
    parser.add_argument('-c', '--classes', type=str, default='class.csv')
    args = parser.parse_args()
    return args

#獲取特定後綴名的文件列表
def get_file_index(indir, postfix):
    file_list = []
    for root, dirs, files in os.walk(indir):
        for name in files:
            if postfix in name:
                file_list.append(os.path.join(root, name))
    return file_list

#寫入標註信息
def convert_annotation(csv, address_list):
    cls_list = []
    with open(csv, 'w') as f:
        for i, address in enumerate(address_list):
            in_file = open(address, encoding='utf8')
            strXml =in_file.read()
            in_file.close()
            root=ET.XML(strXml)
            for obj in root.iter('object'):
                cls = obj.find('name').text
                cls_list.append(cls)
                xmlbox = obj.find('bndbox')
                b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), 
                     int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
                f.write(file_dict[address_list[i]])
                f.write( "," + ",".join([str(a) for a in b]) + ',' + cls)
                f.write('\n')
    return cls_list


if __name__ == "__main__":
    args = parse_args()
    file_address = args.indir
    test_percent = args.percent
    train_csv = args.train
    test_csv = args.val
    class_csv = args.classes
    Annotations = get_file_index(file_address, '.xml')
    Annotations.sort()
    JPEGfiles = get_file_index(file_address, '.JPG') #可根據自己數據集圖片後綴名修改
    JPEGfiles.sort()
    assert len(Annotations) == len(JPEGfiles) #若XML文件和圖片文件名不能一一對應即報錯
    file_dict = dict(zip(Annotations, JPEGfiles))
    num = len(Annotations)
    test = random.sample(k=math.ceil(num*test_percent), population=Annotations)
    train = list(set(Annotations) - set(test))

    cls_list1 = convert_annotation(train_csv, train)
    cls_list2 = convert_annotation(test_csv, test)
    cls_unique = list(set(cls_list1+cls_list2))

    with open(class_csv, 'w') as f:
        for i, cls in enumerate(cls_unique):
            f.write(cls + ',' + str(i) + '\n')

進入csv文件夾下,python xml2csv.py -i /home/zbb/keras-retinanet/CSV/data

class.csv:

#類別,序號(從0開始)
#class_name,id

plane,0

train.csv:

#路徑,xmin,ymin,xmax,ymax,類別名
#path/to/image.jpg,x1,y1,x2,y2,class_name
/data/imgs/img_001.jpg,837,346,981,456,plane

 

3訓練

csv 後第一個參數接標註csv文件路徑 , 第二個接類別映射csv文件路徑, 第三個參數可選擇添加交叉驗證集
示例:retinanet-train csv ./train.csv ./class.csv --val-annotations ./val.csv

一般還需指定 --epochs 訓練輪數 默認值50
            --batch-size 一批訓練多少個 默認值1
            --steps 一輪訓練多少步 默認10000 需按照自己數據集size大小計算 steps = size / batch-size

至於是否加載權重訓練,backbone選擇(默認Resnet50,可選參見keras_retinanet/models),學習率大小等按照自己需要進行指定。
有250個樣本, batch_size=1,訓練100輪, 則命令如下:

retinanet-train --batch-size 1 --steps 250 --epochs 100 csv ./train.csv ./class.csv --val-annotations ./val.csv

4.轉化模型


retinanet-convert-model 訓練出的模型地址 轉化後的推斷模型地址

retinanet-convert-model ./snapshots/resnet50_csv_100.h5 ./model/resnet50_csv_100.h5

5.測試

返回上一層文件夾,即是keras-retinanet下,新建test.py文件運行測試,進行修改,可測試保存多張圖片

import keras
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.visualization import draw_box, draw_caption
from keras_retinanet.utils.colors import label_color

import matplotlib.pyplot as plt
import cv2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import numpy as np
import time

import tensorflow as tf

def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

# 設置tensorflow session 爲Keras 後端
keras.backend.tensorflow_backend.set_session(get_session())
#轉化後的推斷模型地址
#model_path = os.path.join('..', 'snapshots', 'predict.h5')
model_path ='/home/zbb/keras-retinanet/CSV/model/resnet50_csv_100.h5'
#加載模型
model = models.load_model(model_path, backbone_name='resnet50')
#建立ID與類別映射字典
labels_to_names = {0: 'plane'}
#加載需要檢測的圖片
#image_path = '/home/zbb/keras-retinanet/CSV/test/50.JPG'
path='/home/zbb/keras-retinanet/CSV/test/'
save_path='/home/zbb/keras-retinanet/CSV/result/'
image_names = sorted(os.listdir(path))
for image_path in image_names:
	image = read_image_bgr(path+image_path)
	print(path+image_path)
	# copy到另一個對象並轉爲RGB文件
	draw = image.copy()
	draw = cv2.cvtColor(draw, cv2.COLOR_BGR2RGB)
	# 圖像預處理
	image = preprocess_image(image)
	image, scale = resize_image(image)
	# 模型預測
	start = time.time()
	boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
	print("processing time: ", time.time() - start)
	# 矯正比例
	boxes /= scale
	# 目標檢測可視化展示
	for box, score, label in zip(boxes[0], scores[0], labels[0]):
		# 設置預測得分最低閾值
		if score < 0.75:
			break
		color = label_color(label)
		b = box.astype(int)
		draw_box(draw, b, color=color)
		caption = "{} {:.3f}".format(labels_to_names[label], score)
		draw_caption(draw, b, caption)
	#圖片展示
	plt.figure(figsize=(15, 15))
	plt.axis('off')
	plt.imshow(draw)
	plt.savefig(save_path+image_path,format='JPG',transparent=True,pad_inches=0,dpi=300,bbox_inches='tight')
	#plt.show()

結果挺好的,檢測了100張,沒有錯誤,速度也還不錯,大概兩個小時

6.評測

loss可視化

tensorboard --logdir='/home/zbb/keras-retinanet/CSV/logs' 

對應是文件夾,結果如下:

ap,precision-recall:

沒有找到,有會的可以博客下面留言交流

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章