基於TensorFlow的SSD車輛檢測-1

此係列博客是用來學習Tensorflow和Python的,由於是新手上車,如有錯誤之處希望大家不吝指出。

整個項目可以從百度雲下載:鏈接:https://pan.baidu.com/s/1f2JPJpE7m5M2kSifMP0-Lw 密碼:9p8v

一. 訓練數據準備

在訓練數據準備環節,主要包含下面三塊內容:

  • 怎樣解析用於車輛檢測訓練的KITTI數據集
  • 怎樣進行數據擴張來增大訓練數據的多樣性
  • 怎樣在訓練階段爲模型供給batch訓練數據

1. 讀取KITTI數據集

首先到KITTI官網http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=2d下載車輛檢測數據集。

具體地,只用下載下面3個壓縮包:(需要提供郵箱以獲得下載鏈接)

KITTI數據集採用了一張圖片對應一個標註文件的形式,其中標註文件是TXT格式,內容爲N行15列,每一列都使用空格隔開。這15列的內容是:

列號 名稱 描述
1 類別 目標類別,共8類:’Car’, ‘Van’, ‘Truck’,’Pedestrian’, ‘Person_sitting’, ‘Cyclist’, ‘Tram’, ‘Misc’ 或者 ‘DontCare’
2 是否有截斷 指目標是否超出圖像邊界,0: (non-truncated), 1: (truncated)
3 遮擋情況 0 = fully visible, 1 = partly occluded 2 = largely occluded, 3 = unknown
4 目標觀測角度 範圍[-pi..pi]
5-8 目標bbox 座標從0開始,[left, top, right, bottom]
9-11 3D維度 3D object dimensions: height, width, length (in meters)
12-14 3D空間座標 D object location x,y,z in camera coordinates (in meters)
15 Y軸旋轉角 Rotation ry around Y-axis in camera coordinates [-pi..pi]
16 置信度得分 僅用於Test,浮點數,用於繪製p/r曲線

備註:‘DontCare’表示忽略的未標記區域,這可能是因爲超出了激光掃描儀的工作範圍。測試時,位於該部分區域的結果會自動被忽略。訓練時可以同樣將此部分忽略,防止在此區域不斷地引起Hard Mining操作。

由於這裏只進行車輛檢測,因此標註信息中我們暫時只關注類別和BBox信息。此外,將’Car’, ‘Van’, ‘Truck’這3類合併爲正樣本目標,其餘區域作爲背景區域。

首先,我們需要批量的讀取每一個標註文件:

# readKITTI.py 用於解析KITTI數據集

import os

# 獲取指定後綴名的文件列表
def get_filelist(path,ext):
    # 獲取某個文件夾下的所有文件
    filelist_temp  = os.listdir(path)
    filelist = []
    # 通過比較後綴,選中所有TXT標註文件
    for i in filelist_temp:
        if os.path.splitext(i)[1] == ext:
            filelist.append(os.path.splitext(i)[0])
    return filelist

# 解析標註文件並返回目標的bounding box信息,維度Nx4
def get_bbox(filename):
    bbox = []
    # 判斷文件是否存在
    if os.path.exists(filename):
        with open(filename) as fi:
            label_data = fi.readlines()
        # 依次讀取每一行標註信息
        for l in label_data:
            data = l.split()
            # 如果存在車輛目標則記錄bounding box
            if data[0] in ['Van','Car','Truck']:
                bbox.append((float(data[4]),float(data[5]),
                    float(data[6]),float(data[7])))
    return bbox

# 批量獲取標註文件的bounding box信息
def get_bboxlist(rootpath,imagelist):
    bboxlist = []
    for i in imagelist:
        bboxlist.append(get_bbox(rootpath + i +'.txt'))
    return bboxlist

通過調用上述函數,我們便可以讀取KITTI數據集爲我們需要的形式:

import readKITTI

IMAGE_DIR = './image/training/image_2/'
LABEL_DIR = './label/training/label_2/'

imagelist = readKITTI.get_filelist(IMAGE_DIR,'.png')
bboxlist  = readKITTI.get_bboxlist(LABEL_DIR,imagelist)

2. 數據擴張

在深度學習模型的訓練過程中,數據擴張(Data Augmentation)通常都會被使用。其中,隨機縮放、所及裁剪、隨機翻轉應當是使用最廣泛的且行之有效的手段。(至於對比度調整、顏色調整、PCA這些東西,還真不好說。)

對於目標檢測而言,相當重要的一點是:對圖像進行調整的同時,也要保證目標bounding box的有效性與正確性。

縮放

爲了後續模型訓練的時候可以使用Batch,通常我們會將輸入圖像固定到統一尺寸,因此圖像resize並調整顏色統一是必不可少的。

# imAugment.py 提供一些用於數據擴張的函數

import cv2

# 將圖像按照指定尺寸進行縮放,同時處理boundingbox以及顏色信息
def imresize(in_img,in_bbox,out_w,out_h,is_color = True):
    # 判斷是否是字符串
    if isinstance(in_img,str):
        in_img = cv2.imread(in_img)
    # 獲取圖像寬度與高度
    height, width = in_img.shape[:2]
    out_img = cv2.resize(in_img,(out_w, out_h))
    # 調整圖像顏色
    if is_color == True and in_img.ndim == 2 :
        out_img = cv2.cvtColor(out_img, cv2.COLOR_GRAY2BGR)
    elif is_color == False and in_img.ndim == 3 :
        out_img = cv2.cvtColor(out_img, cv2.COLOR_BGR2GRAY)
    # 調整bounding box
    s_h = out_h / height
    s_w = out_w / width
    out_bbox = []
    for i in in_bbox:
        out_bbox.append((i[0]*s_w, i[1]*s_h, i[2]*s_w, i[3]*s_h))
    return out_img, out_bbox

水平翻轉

對於車輛檢測,垂直翻轉沒有必要,我們這裏只進行水平翻轉,並對應的翻轉bounding box。

# imAugment.py 提供一些用於數據擴張的函數

# 將圖像進行水平翻轉,同時處理boundingbox
def immirror(in_img,in_bbox):
    # 判斷是否是字符串
    if isinstance(in_img,str):
        in_img = cv2.imread(in_img)
    # 圖像水平翻轉
    out_img = cv2.flip(in_img,1)
    # 獲取圖像寬度
    width = out_img.shape[1]
    # 重新調整目標在翻轉後圖像上的位置
    out_bbox = []
    for i in in_bbox:
        out_bbox.append((width - i[0], i[1], width-i[2], i[3]))
    return out_img, out_bbox

隨機裁剪

隨機裁剪其實有很多約束和注意事項,主要有下面幾點:

  • 需要指定最小裁剪塊的大小。否則如果裁剪塊過小,則不適用於訓練。
  • 過小的圖像不應當再被裁剪
  • 由於我們無法準確的形容一個被裁剪掉一塊的目標是否還是一個有效的可被識別的目標,因此我們的裁剪區域應當包含所有目標的bounding box。
# imAugment.py 提供一些用於數據擴張的函數

import random
# 將圖像進行隨機crop,同時處理boundingbox, min_wh爲crop塊的最小寬高
def imcrop(in_img,in_bbox,min_hw):
    # 判斷是否是字符串
    if isinstance(in_img,str):
        in_img = cv2.imread(in_img)
    # 獲取圖像寬度與高度
    height, width = in_img.shape[:2]
    # 如果圖像過小,則放棄crop
    if height <= min_hw and width <= min_hw:
        return in_img, in_bbox
    # 爲了防止有效目標被crop截斷,crop範圍應包含所有目標
    # 下面尋找包含所有目標的最小矩形
    min_x1, min_y1, min_x2, min_y2 = width-1, height-1, 0, 0
    for i in in_bbox:
        min_x1 = min(min_x1,int(i[0]))
        min_y1 = min(min_y1,int(i[1]))
        min_x2 = max(min_x2,int(i[2]))
        min_y2 = max(min_y2,int(i[3]))

    # 根據最小包圍框,再隨機生成一個矩形框,並防止框超出圖像範圍
    rand_x1, rand_y1, rand_x2, rand_y2 = 0, 0, width, height
    if min_x1 <= 1:
        rand_x1 = 0
    else:
        rand_x1 = random.randint(0,min(min_x1,max(width - min_hw,1)))
    if min_y1 <= 1:
        rand_y1 = 0
    else:
        rand_y1 = random.randint(0,min(min_y1,max(height - min_hw,1)))
    if min_x2 >= width or rand_x1 + min_hw >= width:
        rand_x2 = width
    else:
        rand_x2 = random.randint(max(rand_x1+min_hw,min_x2),width)
    if min_y2 >= height or rand_y1 + min_hw >= height:
        rand_y2 = height
    else:
        rand_y2 = random.randint(max(rand_y1+min_hw,min_y2),height)

    # crop圖像
    out_img = in_img[rand_y1:rand_y2-1,rand_x1:rand_x2-1]
    # 處理bounding box
    out_bbox = []
    for i in in_bbox:
        out_bbox.append((i[0]-rand_x1,i[1]-rand_y1,i[2]-rand_x1,i[3]-rand_y1))
    return out_img, out_bbox

下面給出效果圖:(最上面的是原圖,下面依次是水平翻轉,縮放和隨機裁剪)

這裏寫圖片描述

3. Batch生成

訓練階段我們需要生成一個個batch用於訓練,一般需要的參數設置包括:batchsize、訓練圖片的大小、顏色、是否shuffle數據、是否隨機crop等。基於此,下面給出一個供給batch的代碼:

# genBatch.py 用於訓練階段供給訓練數據

# coding=utf-8
import random
import readKITTI
import imAugment
import cv2


class genBatch:
    image_dir, label_dir = [], []
    image_list, bbox_list = [], []
    initOK = False

    # 初始化讀取數據
    def initdata(self, imagedir, labeldir):
        self.image_dir, self.label_dir = imagedir, labeldir
        self.image_list = readKITTI.get_filelist(imagedir,'.png')
        self.bbox_list  = readKITTI.get_bboxlist(labeldir,self.image_list)
        # 如果數據不爲空且圖片和label數量相匹配
        if len(self.image_list) > 0 and len(self.image_list) == len(self.bbox_list):
           self.initOK = True
        else:
            print("The amount of images is %d, while the amount of"
                    "corresponding label is %d"%(len(self.image_list),len(self.bbox_list)))
            self.initOK = False
        return self.initOK

    readPos = 0

    # 生成一個新的batch
    def genbatch(self,batchsize,newh,neww,iscolor=True,isshuffle=False,
                mirrorratio=0.0, cropratio=0.0):
        if self.initOK == False:
            print("The initdata() function must be successfully called first.")
            return []
        batch_data, batch_bbox = [], []
        for i in range(batchsize):
            # 當數據遍歷一遍時
            if self.readPos >= len(self.image_list)-1:
                self.readPos = 0
                if isshuffle == True:
                    # 指定同一隨機種子,保證圖片和label採用同樣的亂序
                    r_seed = random.random()
                    random.seed(r_seed)
                    random.shuffle(self.image_list)
                    random.seed(r_seed)
                    random.shuffle(self.bbox_list)
            img = cv2.imread(self.image_dir + self.image_list[self.readPos] + '.png')
            bbox = self.bbox_list[self.readPos]
            self.readPos += 1

            # 按照指定概率進行crop,切記裁剪應當發生在resize之前
            if cropratio > 0 and random.random() < cropratio:
                img, bbox = imAugment.imcrop(img,bbox,min(neww,newh))

            # 調整圖像大小及顏色
            img, bbox = imAugment.imresize(img,bbox,neww,newh,iscolor)

            # 按照指定概率進行隨機鏡像
            if mirrorratio > 0 and random.random() < mirrorratio:
                img, bbox = imAugment.immirror(img,bbox)

            batch_data.append(img)
            batch_bbox.append(bbox)
        return batch_data, batch_bbox
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章