使用Pytorch進行Faster R-CNN目標檢測

 

在這篇文章中,我們將通過使用pytorch進行faster R-CNN目標檢測。我們將學習目標檢測從R-CNN到 fast R-CNN到 faster R-CNN的演變過程。

1、圖像分類VS目標檢測

圖像分類用於將類標籤分配給輸入圖像。例如,給定貓的輸入圖像,圖像分類算法的輸出是標籤“貓”。 在目標檢測中,我們不僅對輸入圖像中的對象感興趣,而且對它們的位置也感興趣。 下圖說明了圖像分類和目標檢測的區別。

1.1.圖像分類與目標檢測:使用哪一個?

通常,圖像分類用於處理只包含一個對象的圖像。可能有多個類(例如。貓、狗等),但是通常,圖像中只有一個類的實例。 在輸入圖像中有多個對象的大多數應用程序中,我們需要找到對象的位置,然後對它們進行分類。在這種情況下,我們使用目標檢測算法。 目標檢測比圖像分類在速度上慢數百倍,因此,在應用中,如果對象在圖像中的位置並不重要,那我們最好還是使用圖像分類。

2.目標檢測

我們可以把目標檢測分成兩步

1、查找包含對象的包圍框,使每個包圍框只有一個對象。

2、對每個包圍框中的圖像進行分類,併爲其確定一個標籤。

在接下來的幾節中,我們將介紹faster R-CNN目標檢測是如何一步步發展出來的。

2.1滑動窗口法

大多數經典的用於目標檢測的計算機視覺技術,如HAAR(哈爾)級聯和HOG(方向梯度統計直方圖)+SVM,都使用滑動窗口方法來檢測目標。 在這種方法中,滑動窗口在圖像上移動,滑動窗口中的所有像素都被裁剪出來併發送到圖像分類器。 如果圖像分類器識別到已知對象,則存儲邊界框和類標籤。否則,窗口繼續滑動並將被評估。 滑動窗口方法非常浪費計算資源,因爲爲了檢測輸入圖像中的對象,需要在圖像中的每個像素上評估不同尺度和縱橫比的滑動窗口。 因此,只有當我們檢測具有固定縱橫比的單個對象類時,才使用滑動窗口。例如,OpenCV中基於HOG + SVM或HAAR的人臉檢測器使用滑動窗口方法。在人臉檢測器中,複雜性是可控的,因爲只有方形包圍框在不同的尺度上被檢測。

2.2.R-CNN目標檢測器

基於卷積神經網絡(CNN)的圖像分類器在2012年贏得ImageNet大規模視覺識別挑戰(ILSVRC)後變得流行起來。 由於每個目標檢測器的核心都有一個圖像分類器,基於CNN的目標檢測器的發明變得不可避免。

有兩個問題需要解決

1、與傳統的HOG + SVM或HAAR級聯等技術相比,基於CNN的圖像分類器非常耗費計算資源。

2、計算機視覺社區越來越雄心勃勃。人們想要建立一個多類目標檢測器,除了不同的尺度外,還可以處理不同的縱橫比。

因此,基於滑動窗口的目標檢測方法被淘汰了。它太費資源了。

研究人員開始研究一種新的思路,即訓練一種機器學習模型,該模型可以提出包含對象的包圍框的位置。這些包圍框被稱爲區域提案或目標提案。

區域提案僅僅是一個列表,元素是包含對象的包圍框,而這些包圍框包含對象的概率也不是太高,只是具有一定概率。它不知道也不在乎邊界框中包含了哪種對象。

區域提案算法在不同的位置、尺度和縱橫比上輸出包含幾百個包圍框的列表。這些包圍框中的大多數不包含任何對象。 爲什麼區域提案仍然有用? 在區域提案算法提出的幾百個邊界框中使用圖像分類器,比在滑動窗口方法提出的幾十萬個甚至數百萬個邊界框中使用圖像分類器要高效的多。 使用區域提案的第一種方法之一是RossGirschick等人,稱爲R-CNN(具有CNN特徵的區域的簡稱)。

他們使用一種名爲選擇搜索的算法來檢測出2000個區域提案,並在這些2000個包圍框上運行了一個基於CNN + SVM的圖像分類器。 當時R-CNN的準確性是最先進的,但速度仍然很慢(GPU上的每張圖像18-20秒)

 2.3 fast R-CNN目標探測器

在R-CNN中,每個邊界框由圖像分類器獨立分類,有2000個區域提案,圖像分類器計算了每個區域提案的特徵圖。這個過程很耗時。 在Ross Girshick的後續工作中,他提出了一種稱爲快速R-CNN的方法,它顯著地加快了目標檢測的速度。 其想法是爲整個圖像計算一個單一的特徵圖,而不是爲2000個區域提案計算2000個特徵圖。對於每個區域提案,感興趣區域(ROI)池化層從特徵圖中提取固定長度的特徵向量。然後,每個特徵向量被用於兩個目的

1、將區域分類爲其中一個類(例如。狗,貓,背景)。

2、使用邊界框迴歸器提高原始邊界框的精度。

2.4 faster R-CNN目標檢測器

在 fast R-CNN中,即使對2000個區域提案進行分類的計算是共享的,但生成區域提案的算法部分與執行圖像分類的部分不共享任何計算。 在被稱爲FasterR-CNN的後續工作中,主要的見解是這兩個部分-計算區域提案和圖像分類-可以使用相同的特徵圖,從而分擔計算負荷。 利用卷積神經網絡生成圖像特徵圖,同時用於訓練區域提案網絡和圖像分類器。由於這種共享計算,對象檢測的速度有了顯著的提高。

3.用PyTorch進行目標檢測 [代碼部分]

在本節中,我們將學習如何使用Pytorch中的 faster R-CNN目標檢測器。我們將使用torchvision中的預訓練模型。在PyTorch的所有預先訓練的模型都可以在torchvision.models中找到。

3.1.輸入和輸出

我們將使用預訓練模型faster R-CNN ResNet-50,該模型期望輸入圖像張量的形式[n,c,h,w],要求最小尺寸爲800px。

n是圖像的數目

c爲通道數,對於RGB圖像是3

h是圖像的高度

w是圖像的寬度

模型返回結果

所有預測的包圍框組成的二維列表,維度爲(N,4),其中N是由模型預測圖像中的包圍框個數,包圍框的形狀爲[x0,y0,x1,y1],也就是包圍框的左上角座標和右下角座標

所有預測類的標籤。

每個預測標籤的置信度。

3.2.預訓模式

下載預訓練模型,Resnet50 Faster R-CNN,帶有訓練好的權重參數。

from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
import torchvision
import torch
import numpy as np
import cv2

# get the pretrained model from torchvision.models
# Note: pretrained=True will get the pretrained weights for the model.
# model.eval() to use the model for inference
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

定義Pytorch官方文檔給出的類名稱

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

 我們可以在列表中看到一些N/A,因爲在後面的論文中刪除了一些類。我們將按照PyTorch給出的列表進行操作。

3.3.模型預測

讓我們定義一個函數來獲得圖像路徑,並通過模型得到圖像的預測。

def get_prediction(img_path, threshold):
  img = Image.open(img_path) # Load the image
  transform = T.Compose([T.ToTensor()]) # Defing PyTorch Transform
  img = transform(img) # Apply the transform to the image
  pred = model([img]) # Pass the image to the model
  print('pred')
  print(pred)
  pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())] # Get the Prediction Score
  print("original pred_class")
  print(pred_class)
  pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())] # Bounding boxes
  print("original pred_boxes")
  print(pred_boxes)
  pred_score = list(pred[0]['scores'].detach().numpy())
  print("orignal score")
  print(pred_score)
  pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1] # Get list of index with score greater than threshold.
  pred_boxes = pred_boxes[:pred_t+1]
  pred_class = pred_class[:pred_t+1]
  print(pred_t)
  print(pred_boxes)
  print(pred_class)
  return pred_boxes, pred_class

從圖像路徑中獲取圖像

使用PyTorch變換將圖像轉換爲圖像張量

通過在模型傳遞圖像以得到預測結果

得到類、包圍框座標,但只選擇預測分數>閾值的結果。

函數處理過程中的階段結果都被打印出來用於分析實驗結果

 3.4.目標檢測流程 

接下來,我們將定義一個函數來獲取圖像路徑並獲得輸出圖像。

def object_detection_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
 
  boxes, pred_cls = get_prediction(img_path, threshold) # Get predictions
  img = cv2.imread(img_path) # Read image with cv2
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert to RGB
  for i in range(len(boxes)):
    cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th) # Draw Rectangle with the coordinates
    cv2.putText(img,pred_cls[i], boxes[i][0],  cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th) # Write the prediction class
  plt.figure(figsize=(20,30)) # display the output image
  plt.imshow(img)
  plt.xticks([])
  plt.yticks([])
  plt.show()

用get_prediction函數進行預測

對於每個預測, 使用opencv 繪製包圍框並添加標籤文本

最終圖像顯示

3.5.推理

現在,讓我們使用以上代碼來檢測一些圖像中的目標。預訓練模型在CPU中需要大約8秒的推理,在NVIDIA GTX1080 Ti GPU中需要0.15秒。本例中使用cpu進行推理。

示例1:

在當前目錄下載圖片:

wget https://www.wsha.org/wp-content/uploads/banner-diverse-group-of-people-2.jpg -O people.jpg

運行python代碼:

object_detection_api('./people.jpg', threshold=0.8)

最終圖片:

處理過程的打印信息如下,可見,模型輸出爲一個字典,共包含3個元素,分別是boxes,labels和scores。而且已經按照scores的得分進行排序。所以閾值的選擇變的很關鍵,閾值過高會遺漏目標,過低則會增加誤報。

pred
[{'boxes': tensor([[0.0000e+00, 3.1965e+02, 4.4152e+02, 1.3233e+03],
        [7.6530e+02, 3.5194e+02, 1.1937e+03, 1.3123e+03],
        [1.1569e+03, 3.9462e+02, 1.5692e+03, 1.3222e+03],
        [1.5193e+03, 4.2957e+02, 1.9924e+03, 1.3306e+03],
        [3.6411e+02, 3.8761e+02, 7.7179e+02, 1.3062e+03],
        [6.1343e+02, 2.6801e+02, 9.1222e+02, 1.2417e+03],
        [1.4390e+03, 3.3610e+02, 1.6700e+03, 7.9422e+02],
        [1.7240e+03, 2.8933e+02, 1.9978e+03, 8.8992e+02],
        [1.1548e+03, 1.8348e+02, 1.4562e+03, 7.0534e+02],
        [1.6003e+03, 1.9746e+02, 1.7637e+03, 6.3541e+02],
        [8.8223e+02, 1.9024e+02, 1.2439e+03, 7.4672e+02],
        [5.4620e+02, 2.4167e+02, 7.1724e+02, 5.4152e+02],
        [2.3673e+02, 1.0112e+02, 4.8953e+02, 6.6732e+02],
        [8.3154e+02, 2.0397e+02, 1.1441e+03, 6.0038e+02],
        [2.1523e+02, 1.5433e+02, 4.8707e+02, 1.1148e+03],
        [1.1424e+03, 1.9198e+02, 1.5366e+03, 9.7392e+02],
        [0.0000e+00, 6.2382e+02, 1.6149e+02, 1.2446e+03],
        [1.5768e+03, 2.0830e+02, 1.8524e+03, 7.4405e+02],
        [2.5804e+02, 3.3354e+02, 6.2200e+02, 1.3077e+03],
        [8.9626e+02, 3.5898e+02, 1.3640e+03, 1.3158e+03],
        [1.0498e+03, 1.0657e+03, 1.0759e+03, 1.1004e+03],
        [5.3046e+00, 4.1446e+02, 2.1460e+02, 1.3151e+03],
        [1.5564e+03, 2.1095e+02, 1.9997e+03, 9.6501e+02],
        [4.5241e+01, 9.1016e+02, 1.5961e+02, 1.0559e+03],
        [1.4584e+03, 3.3478e+02, 1.7331e+03, 1.2968e+03],
        [1.8249e+03, 2.9717e+02, 2.0000e+03, 1.2894e+03],
        [1.0173e+03, 2.9484e+02, 1.2510e+03, 7.7417e+02],
        [1.7593e+02, 5.9531e+02, 3.9948e+02, 9.5223e+02],
        [9.5730e+02, 1.7660e+02, 1.6549e+03, 7.6932e+02],
        [8.1830e+00, 1.1709e+03, 1.5495e+02, 1.3270e+03],
        [1.4883e+03, 2.6912e+02, 1.7646e+03, 7.7534e+02],
        [5.5098e+01, 9.1188e+02, 8.6513e+01, 9.8316e+02],
        [3.7265e+02, 7.5000e+02, 4.7406e+02, 1.3200e+03],
        [0.0000e+00, 5.6960e+02, 1.6152e+02, 8.6986e+02],
        [1.0386e+03, 1.0430e+03, 1.0768e+03, 1.0974e+03],
        [1.2317e+03, 3.0593e+02, 1.6184e+03, 8.5238e+02],
        [3.6448e-01, 8.9336e+02, 8.7396e+01, 9.7809e+02],
        [1.1957e+00, 7.1775e+02, 1.5803e+02, 1.1761e+03],
        [6.7784e+01, 1.7359e+02, 4.4879e+02, 7.4864e+02]],
       grad_fn=<StackBackward>), 'labels': tensor([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 31,  1,
         1,  1, 85,  1,  1,  3,  1,  1,  1, 32,  1,  1,  1, 14,  1, 31, 85,  1,
         3, 28,  1]), 'scores': tensor([0.9986, 0.9983, 0.9982, 0.9971, 0.9964, 0.9874, 0.9854, 0.9853, 0.9773,
        0.9734, 0.9606, 0.8864, 0.8327, 0.7735, 0.3906, 0.3572, 0.3154, 0.2906,
        0.2022, 0.1946, 0.1904, 0.1621, 0.1552, 0.1310, 0.1213, 0.1161, 0.1100,
        0.1049, 0.0986, 0.0899, 0.0815, 0.0799, 0.0691, 0.0663, 0.0660, 0.0652,
        0.0590, 0.0584, 0.0550], grad_fn=<IndexBackward>)}]
original pred_class
['person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'handbag', 'person', 'person', 'person', 'clock', 'person', 'person', 'car', 'person', 'person', 'person', 'tie', 'person', 'person', 'person', 'parking meter', 'person', 'handbag', 'clock', 'person', 'car', 'umbrella', 'person']
original pred_boxes
[[(0.0, 319.6455), (441.5155, 1323.2559)], [(765.3034, 351.93936), (1193.7164, 1312.2843)], [(1156.9446, 394.623), (1569.2057, 1322.1747)], [(1519.3403, 429.56564), (1992.429, 1330.6052)], [(364.10757, 387.61426), (771.79016, 1306.173)], [(613.42804, 268.0113), (912.2196, 1241.745)], [(1439.0381, 336.10446), (1669.9921, 794.2212)], [(1723.9878, 289.32892), (1997.8444, 889.91785)], [(1154.8141, 183.47633), (1456.161, 705.3361)], [(1600.269, 197.45708), (1763.7482, 635.4068)], [(882.2303, 190.24492), (1243.9407, 746.71704)], [(546.1972, 241.6657), (717.2413, 541.52277)], [(236.73463, 101.11602), (489.53278, 667.32404)], [(831.5391, 203.96678), (1144.0723, 600.3847)], [(215.23401, 154.3328), (487.06903, 1114.8004)], [(1142.3718, 191.98454), (1536.619, 973.9194)], [(0.0, 623.81696), (161.48994, 1244.6355)], [(1576.7936, 208.30457), (1852.4454, 744.0468)], [(258.0397, 333.54254), (622.0048, 1307.7097)], [(896.257, 358.9788), (1364.0449, 1315.7703)], [(1049.787, 1065.661), (1075.9006, 1100.3923)], [(5.3046417, 414.46304), (214.60019, 1315.0636)], [(1556.4305, 210.94743), (1999.7416, 965.0133)], [(45.241035, 910.1611), (159.61392, 1055.8945)], [(1458.4309, 334.77945), (1733.1425, 1296.8483)], [(1824.8823, 297.17184), (2000.0, 1289.3732)], [(1017.25616, 294.84235), (1251.0472, 774.17303)], [(175.93126, 595.3143), (399.48138, 952.2279)], [(957.2965, 176.59543), (1654.9402, 769.3168)], [(8.183002, 1170.8748), (154.95354, 1327.0237)], [(1488.3322, 269.119), (1764.643, 775.338)], [(55.097706, 911.8797), (86.512726, 983.1622)], [(372.65338, 750.0016), (474.0599, 1319.9922)], [(0.0, 569.6015), (161.51752, 869.8641)], [(1038.5804, 1042.9966), (1076.7808, 1097.398)], [(1231.6813, 305.93085), (1618.3899, 852.3754)], [(0.3644816, 893.3566), (87.39571, 978.09436)], [(1.1957486, 717.7485), (158.03418, 1176.0779)], [(67.78371, 173.59235), (448.7855, 748.63654)]]
orignal score
[0.9986298, 0.9983411, 0.998161, 0.997095, 0.9964186, 0.98744327, 0.9854257, 0.98530376, 0.97725266, 0.97337264, 0.96058786, 0.88644713, 0.83272547, 0.7734675, 0.39064288, 0.35724017, 0.3154308, 0.29059145, 0.20217809, 0.19457912, 0.19040918, 0.16206688, 0.15520865, 0.13098373, 0.121276304, 0.116065815, 0.11003116, 0.10491751, 0.09857031, 0.089911185, 0.0815428, 0.07988655, 0.06910095, 0.06626528, 0.06596641, 0.0652364, 0.058985297, 0.05841243, 0.054988183]
12
[[(0.0, 319.6455), (441.5155, 1323.2559)], [(765.3034, 351.93936), (1193.7164, 1312.2843)], [(1156.9446, 394.623), (1569.2057, 1322.1747)], [(1519.3403, 429.56564), (1992.429, 1330.6052)], [(364.10757, 387.61426), (771.79016, 1306.173)], [(613.42804, 268.0113), (912.2196, 1241.745)], [(1439.0381, 336.10446), (1669.9921, 794.2212)], [(1723.9878, 289.32892), (1997.8444, 889.91785)], [(1154.8141, 183.47633), (1456.161, 705.3361)], [(1600.269, 197.45708), (1763.7482, 635.4068)], [(882.2303, 190.24492), (1243.9407, 746.71704)], [(546.1972, 241.6657), (717.2413, 541.52277)], [(236.73463, 101.11602), (489.53278, 667.32404)]]
['person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person']

示例2:

 下載圖片:

wget https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/10best-cars-group-cropped-1542126037.jpg -O car.jpg

進行預測:

object_detection_api('./car.jpg', rect_th=6, text_th=5, text_size=5)

最終圖像:

示例3:

下載圖片:

wget https://cdn.pixabay.com/photo/2013/07/05/01/08/traffic-143391_960_720.jpg -O traffic_scene.jpg

進行預測:

object_detection_api('./traffic_scene.jpg', rect_th=2, text_th=1, text_size=1)

 最終結果:

示例4:

下載圖像:

wget https://images.unsplash.com/photo-1458169495136-854e4c39548a -O girl_cars.jpg

 進行預測:

object_detection_api('./girl_cars.jpg', rect_th=15, text_th=7, text_size=5, threshold=0.8)

最終結果:

 

4、模型在CPU和GPU中推理時間對比

import time

def check_inference_time(image_path, gpu=False):
  model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
  model.eval()
  img = Image.open(image_path)
  transform = T.Compose([T.ToTensor()])
  img = transform(img)
  if gpu:
    model.cuda()
    img = img.cuda()
  else:
    model.cpu()
    img = img.cpu()
  start_time = time.time()
  pred = model([img])
  end_time = time.time()
  return end_time-start_time

cpu_time = sum([check_inference_time('./girl_cars.jpg', gpu=False) for _ in range(10)])/10.0
gpu_time = sum([check_inference_time('./girl_cars.jpg', gpu=True) for _ in range(10)])/10.0


print('\n\nAverage Time take by the model with GPU = {}s\nAverage Time take by the model with CPU = {}s'.format(gpu_time, cpu_time))

輸出結果:

Average Time take by the model with GPU = 0.15356571674346925s
Average Time take by the model with CPU = 8.458594107627869s

參考鏈接:https://www.learnopencv.com/faster-r-cnn-object-detection-with-pytorch/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章