An Introduction to the Faster RCNN Family

The RCNN Network

RCNN, proposed by Ross B. Girshick in 2014, was the first deep-learning-based object detection algorithm, built on deep features. What distinguishes it from traditional object recognition is that it replaces hand-crafted low-level features such as color and texture with features learned by a deep network.

Because RCNN did not redesign or optimize the overall framework of traditional object detection, it still suffers from that framework's problems, such as low detection speed and efficiency. Its detection accuracy, although somewhat improved over earlier detection algorithms, still falls short of practical requirements. These weaknesses became the main directions for improvement in subsequent algorithms.

How does the RCNN network use deep-learning features?

  1. RCNN does not train a convolutional network from scratch; it starts from AlexNet, a classification model pre-trained on ImageNet.
  2. That pre-trained model is fine-tuned (transferred) to RCNN's classification task. During the transfer the FC layers are removed: AlexNet was built for a 1000-class classification task, while on the VOC detection dataset there are 20 object classes plus background, i.e. 21 classes. Because the FC parameters no longer match, only the backbone, AlexNet's convolutional layers, is kept, and these features are used to train the model further. On top of the CNN features, RCNN still uses SVM classifiers and bounding-box regression to obtain each candidate region's class and location.
  3. Candidate regions are extracted with a selective search strategy rather than sliding windows. Sliding windows require choosing both a stride and the window sizes, which usually means evaluating an enormous number of candidate regions; as that number grows, computational performance drops and the whole detection pipeline becomes very slow. RCNN therefore uses selective search to extract proposals:
    1. Generate the initial region set R (see the paper "Efficient Graph-Based Image Segmentation"): a region-segmentation model partitions the image into a set of segments, which often correspond to possible objects in the image.
    2. Compute the similarities S={s1,s2,...} between all adjacent regions in R.
    3. Find the two most similar regions, merge them into a new region, and add it to R, reducing the number of regions. The similarity of two regions is usually judged with low-level computer-vision features such as color and texture.
    4. Remove from S every similarity involving the regions merged in step 3.
    5. Compute the similarities between the new region and all its neighbors.
    6. Return to step 3 until S is empty, i.e. no similar regions remain, which ends the iteration. The regions that remain form the set of candidate boxes to extract.
  4. Train the SVM classifiers, one SVM per class. For background on SVMs, see the support vector machine section of 機器學習算法整理(三).
  5. Refine the candidate-box locations with a regressor: a linear regression model judges how well each box fits.

For a given candidate region, the cases shown in the figure above can occur. The linear regression model regresses four values, dx, dy, dw and dh, which can be understood as difference values describing the current candidate box's offset. If the box covers the target completely and tightly, all offsets are 0. If the box is shifted to one side, dx predicts the corresponding offset; for the cat in the middle of the figure, for instance, dx comes out as 0.25. If the box is larger than the target, the predicted dw or dh reflects the surplus: in the third image the box is clearly wider than the cat, so dw is predicted as -0.125, meaning the width should shrink by 0.125, while the height fits exactly, so dh is 0.
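To make these offsets concrete, here is a minimal sketch of the standard RCNN box parameterization (the figure's numbers are plain fractional offsets, whereas the paper regresses dw and dh in log space, as below; the function name and the example boxes are ours, not from the original):

import numpy as np

def bbox_regression_targets(proposal, gt):
    # boxes are (x1, y1, x2, y2); this is the usual RCNN offset parameterization
    pw, ph = proposal[2] - proposal[0], proposal[3] - proposal[1]
    px, py = proposal[0] + 0.5 * pw, proposal[1] + 0.5 * ph
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    gx, gy = gt[0] + 0.5 * gw, gt[1] + 0.5 * gh
    dx = (gx - px) / pw       # horizontal center shift, relative to the proposal width
    dy = (gy - py) / ph       # vertical center shift, relative to the proposal height
    dw = np.log(gw / pw)      # log-space width correction
    dh = np.log(gh / ph)      # log-space height correction
    return dx, dy, dw, dh

# a perfectly aligned proposal yields all-zero offsets
print(bbox_regression_targets((10, 10, 50, 50), (10, 10, 50, 50)))  # (0.0, 0.0, 0.0, 0.0)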

RCNN thus produces a set of candidate boxes; the NMS algorithm then filters and merges them further to give the final detection output.
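For reference, here is a minimal NumPy sketch of greedy NMS (the function name and default threshold are our choices): keep the highest-scoring box, discard every remaining box whose IoU with it exceeds the threshold, and repeat until no boxes are left.

import numpy as np

def nms(boxes, scores, iou_thresh=0.5):
    # boxes: (N, 4) as (x1, y1, x2, y2); scores: (N,); returns indices of the kept boxes
    order = scores.argsort()[::-1]                 # process highest-scoring boxes first
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # IoU of the current best box against all remaining boxes
        x1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        y1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        x2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        y2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
        inter = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= iou_thresh]       # drop boxes that overlap too much
    return keep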

  • Drawbacks of RCNN
  1. The candidate-box selection algorithm is very time-consuming; real-time computation is essentially out of reach.
  2. Features of overlapping regions are computed repeatedly.
  3. The pipeline runs in separate stages and is cumbersome. At minimum, three parts must be trained separately: the CNN, the SVMs, and the regressor. Compared with an end-to-end network model, the process is much more complex, and assembling the whole pipeline is equally involved. For instance, training the SVMs requires generating dedicated SVM samples, and these may differ from the samples obtained through the CNN's convolutional features, creating a distribution mismatch. When independently trained stages are stitched together, some performance is usually lost. A better approach is to build the pipeline end to end, feeding the CNN output into the SVM and the SVM output into the regression network; as it stands, however, three different models must be trained and their respective samples constructed, which makes the whole process correspondingly tedious.

SPPNet

SPPNet improves on RCNN's drawbacks to a certain degree. The improvement centers on a single point: how to share convolutional features. SPPNet introduces a spatial pyramid pooling layer, and with it the convolutional features can be shared. Let's look at how this pyramid pooling layer works.

The figure above shows the feature-extraction process the original RCNN performs during detection. The input image first undergoes candidate-region extraction, typically via a Crop/warp operation: Crop cuts the candidate region out of the image, and the patch is then resized to a fixed size. Because of this crop-and-resize step, the obtained candidate regions may be distorted and warped. The resized, fixed-size patch is then fed into the convolutional network for feature extraction, and an FC layer finally produces the output vector. Using an FC layer means the feature maps output by the convolutional layers must have identical sizes across candidate regions; and since the same convolutional network is used, its input images must all have the same size, which is why a resize operation is mandatory. This resizing stretches and distorts the candidate regions to some degree, which in turn affects the quality of the extracted features.

Moreover, since candidate-region extraction comes right after the image input, each candidate region is typically pushed through a separate convolution pass for feature extraction, so part of the computation is repeated. This is a very serious problem for RCNN: computation is wasted.

SPPNet optimizes the flow in the figure in a few ways. First, input images are fed in at different sizes, which yields feature maps of different sizes. To turn these differently sized feature maps into an FC input of a single fixed size, SPPNet introduces the SPP layer.

The structure of the SPP layer is as follows.

The figure above shows three outputs of different sizes: 16*256-d, 4*256-d, and 256-d; at each of the three scales, features of a different dimensionality are extracted. How do we guarantee that inputs of different sizes still agree in size once they reach the FC layer? If the convolutional output were fed straight into the FC layer, input images of different sizes would still produce differences in the maps at the three scales; those differences would mean different parameter counts at the FC layer, which would break the network. SPPNet handles this by always turning each convolutional output, after passing through the SPP layer, into a fixed 21-dimensional feature. The 21 dimensions are per feature map, i.e. per channel, so the total feature count at the convolutional output is 21 * the number of channels. Take the first feature map in the figure, the 16*256-d one: the map is divided evenly into a 4*4 grid, and each cell yields one feature value, so no matter the map's size the output is 16 values. For the second, 4*256-d map, we want 4 feature values, so the map is divided into a 2*2 grid, one value per cell, giving 4 values per map and 4*256 feature values over 256 channels. The third, 256-d map yields a single value per channel, i.e. a 1*256-dimensional feature vector. With this SPP layer we can ignore the feature map's size and obtain a fixed-size output, and this fixed-length output can serve as the FC layer's input. The SPP layer thus handles candidate regions of any size uniformly and produces a fixed output, which then feeds the subsequent fully connected computation.

How is a single feature value obtained from each grid cell? Through a pooling operation: pooling includes operators such as Max and Mean, and the SPP layer is itself a special pooling layer. SPPNet can therefore take candidate boxes of any size as input and produce a fixed output for the subsequent FC computation. SPPNet also has one crucial property: the convolutional features are extracted from the original image only once, which saves an enormous amount of computation. RCNN suffers from repeated convolutional computation, and SPPNet fixes exactly that problem.
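A minimal sketch of the three-level pyramid (4*4 + 2*2 + 1*1 = 21 bins per channel), assuming integer grid splitting and an input no smaller than the finest grid; the function name is ours:

import tensorflow as tf

def spp_layer(feature_map, levels=(4, 2, 1)):
    # max-pool a (1, H, W, C) feature map into a fixed (1, 21*C) vector, whatever H and W are
    _, h, w, c = feature_map.shape
    pooled = []
    for n in levels:
        # split the map into an n*n grid and take the max of each cell
        for i in range(n):
            for j in range(n):
                cell = feature_map[:, i * h // n:(i + 1) * h // n,
                                      j * w // n:(j + 1) * w // n, :]
                pooled.append(tf.reduce_max(cell, axis=[1, 2]))  # (1, C) per cell
    return tf.concat(pooled, axis=1)

# inputs of different sizes collapse to the same output length
print(spp_layer(tf.random.normal((1, 13, 13, 256))).shape)  # (1, 5376) = (1, 21*256)
print(spp_layer(tf.random.normal((1, 10, 15, 256))).shape)  # (1, 5376)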

The SPPNet flow chart also shows that SPPNet's input is not the original image but the post-convolution feature map, possibly with many channels; and since the feature map comes from inputs at several scales, we obtain one feature map per input scale, which matches SPPNet's structure diagram. The channel count is 256 rather than 3 because after convolution the number of channels becomes very high (256 here), and the three different feature maps correspond to the CNN's three input scales. In this way SPPNet saves considerable computation and eliminates the image distortion caused by forcing candidate regions to a fixed size in the original RCNN. SPPNet therefore achieves both higher accuracy and faster detection than RCNN.

Fast RCNN

Fast RCNN is a newer model proposed by RCNN's author after RCNN. Unlike SPPNet, Fast RCNN uses an ROI Pooling layer to crop the candidate regions and produce fixed-size outputs; ROI Pooling can be understood as a single-level SPPNet. Fast RCNN also uses a multi-task network to solve classification and location regression together: the SVM classifier and the linear regression model are no longer used to determine candidate-box classes and regress positions, and a single neural network handles both the classification and the regression task (a sketch of the two heads follows).
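A minimal Keras sketch of that multi-task idea, assuming the ROI-pooled features have already been flattened; the function name and layer sizes are our assumptions, loosely following the common Fast RCNN configuration:

from tensorflow.keras import layers

def multitask_heads(roi_features, num_classes=21):
    # two shared FC layers on top of the flattened ROI features
    x = layers.Dense(4096, activation='relu')(roi_features)
    x = layers.Dense(4096, activation='relu')(x)
    # classification head: num_classes includes the background class
    cls_score = layers.Dense(num_classes, activation='softmax', name='cls_score')(x)
    # regression head: 4 offsets (dx, dy, dw, dh) per class
    bbox_pred = layers.Dense(num_classes * 4, name='bbox_pred')(x)
    return cls_score, bbox_pred

The two heads are trained jointly, typically with a cross-entropy loss on the class scores and a smooth L1 loss on the box offsets.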

Borrowing SPPNet's strengths, Fast RCNN likewise shares convolution to save convolutional computation and remove the duplicated work. Fast RCNN laid the groundwork that made Faster RCNN possible.

Fast RCNN network structure

Relative to SPPNet, the Fast RCNN network still contains a shared convolutional stage; what differs is that an ROI Pooling layer produces the fixed-size input to the FC layers, whose final outputs are used for bounding-box regression and candidate-region classification respectively. The main improvements over SPPNet are thus the ROI Pooling layer and the use of a multi-task network; with them, the basic framework of deep-learning object detection took shape. However, Fast RCNN still extracts candidate regions with selective search, so the whole pipeline is not fully end to end. Faster RCNN's improvement over Fast RCNN is precisely the RPN network: by improving the proposal algorithm, the entire detection process is completed by one deep network.

  • ROI Pooling:
  1. A form of SPPNet, and also a form of pooling layer.
  2. It crops the proposal: relative coordinates are used to extract the candidate region's position on the feature map, and the cropped feature map is then resized to a uniform size. The resize is usually carried out with specific pooling operations.
  3. The steps are as follows (see the sketch below):
    1. Given the input image, map the ROI onto the corresponding position on the feature map. The ground-truth region of a candidate target is usually given relative to the original image, while ROI Pooling acts on the post-convolution feature map, so a relative coordinate must be computed to locate the ROI region on the feature map, and that position is then cropped out.
    2. Divide the mapped region into sections of equal size (the number of sections matches the output dimensions), also called blocks or grid cells; the grid count equals the output dimensions.
    3. Apply max pooling to each section to obtain the final fixed-size output.

Applied to candidate regions of different sizes, ROI Pooling yields a fixed-size output, which the subsequent FC layers use to complete the network's location regression and class prediction.
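A minimal sketch of those three steps, assuming a stride-16 backbone, integer grid splitting, and a projected ROI at least output_size pixels on each side (real implementations handle the rounding more carefully); the function name is ours:

import tensorflow as tf

def roi_pooling(feature_map, roi, output_size=7, feat_stride=16):
    # step 1: project the ROI from original-image coordinates onto the feature map
    x1, y1, x2, y2 = [int(round(v / feat_stride)) for v in roi]
    region = feature_map[:, y1:y2 + 1, x1:x2 + 1, :]
    h, w = int(region.shape[1]), int(region.shape[2])
    rows = []
    # steps 2 and 3: divide into a fixed output_size*output_size grid and max-pool each section
    for i in range(output_size):
        cols = []
        for j in range(output_size):
            cell = region[:, i * h // output_size:(i + 1) * h // output_size,
                             j * w // output_size:(j + 1) * w // output_size, :]
            cols.append(tf.reduce_max(cell, axis=[1, 2]))  # (1, C) per section
        rows.append(tf.stack(cols, axis=1))                # (1, output_size, C)
    return tf.stack(rows, axis=1)                          # (1, output_size, output_size, C)

# a 112*112 ROI on a stride-16 backbone becomes a 7*7 grid of pooled features
fm = tf.random.normal((1, 32, 38, 512))
print(roi_pooling(fm, (32, 48, 144, 160)).shape)  # (1, 7, 7, 512)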

  • Fast RCNN performance gains:

The figure above shows that Fast RCNN improves enormously on RCNN in both training and testing time. Training Fast RCNN takes about 9.5 hours, versus 84 hours for RCNN, a speedup of roughly 8x. At test time, RCNN usually needs 47 seconds to detect on a single image, while Fast RCNN completes a detection on a single image in only 0.32 seconds, making the whole detection process about 146x faster.

Fast RCNN's performance gains meant deep-learning object detection could plausibly reach real-time processing, a promising sign for subsequent research.

  • Drawbacks of Fast RCNN:
  1. A bottleneck remains: selective search. Finding all the candidate boxes is very time-consuming.
  2. Can we find a more efficient way to compute these candidate boxes? The answer is the Region Proposal Network (RPN) introduced later in Faster RCNN, which replaces the proposal-extraction step of traditional object detection.

Faster RCNN

Faster RCNN is likewise an object detection framework proposed by Ross B. Girshick.

Its overall architecture is very similar to the traditional object detection architecture, but its organization changed enormously. The convolutional stage uses shared convolution; proposal extraction is done by an RPN network; and the candidate regions' classes and locations are likewise obtained by neural networks. The whole process is end-to-end detection: the image goes in as input, the network in the figure computes, and the output comes out, with great improvements in both simplicity and efficiency. The release of Faster RCNN also marked the maturation of the two-stage framework: later two-stage detectors mostly follow this series of designs, and subsequent deep-learning detection work mainly improves its individual components, for instance adopting a stronger backbone to extract more robust features, or a more rigorous ROI Pooling strategy to raise the model's accuracy.

Faster RCNN network structure

Faster RCNN's biggest difference from Fast RCNN is that the proposal-extraction step uses an RPN network; the other components still rely on shared convolution, ROI Pooling, and neural-network classification and regression to complete the final detection and localization. The backbone contains 13 convolutional layers, 13 activation (relu) layers, and 4 pooling layers. The 4 pooling layers mean the image is downsampled by 2^4 = 16, so when the feature map enters the RPN it is 1/16 the size of the original image.

The RPN does not produce the final decision on object regions; it mainly separates background from foreground, which is really a binary classification problem, while also performing an initial localization of objects. Later stages take the rough object positions the RPN provides and refine them. Faster RCNN likewise includes an ROI Pooling layer, which crops the RPN's proposals and outputs them at a fixed size; this fixed-size output is fed into the following subnetwork for class determination and location refinement. Classification plus location refinement then yields the final output of the deep-learning detector.

  • The RPN network:

The RPN performs two tasks, foreground/background separation and coarse box regression, which we can summarize as coarse localization and coarse classification. Coarse classification means the RPN only judges, for each extracted proposal (candidate box or region), whether it belongs to foreground or background, a binary decision; the subsequent network then determines each candidate region's specific class, so the RPN's classification is coarse. Coarse localization means the RPN also regresses the box positions once, but this regression is likewise rough, relative to the subsequent network that revises the box positions further and obtains more precise locations; the RPN only localizes the boxes approximately.

The RPN replaces the proposal (candidate-box) extraction step of traditional object detection, and this is implemented through the anchor mechanism.

As shown above, the RPN slides a 3*3 window (sliding window) over the feature map. Each 3*3 window region is convolved with a 3*3 kernel, yielding a fixed-length feature vector; with a 3*3 convolution of 256 output channels we obtain a 256-d feature vector for every 3*3 region. Two FC layers then classify the region and regress the box positions. The anchor here refers to the center point of each sliding-window position: the window slides left to right and top to bottom, every point of the feature map fed into the RPN serves as a window center, and that point is called an anchor. For each anchor, windows of several sizes are taken from the original image; these differently sized windows, after pooling, map onto this 3*3 region. In other words, the anchor establishes the correspondence between the original image and a 3*3 window on the current feature map. If we take 9 regions from the original image, those 9 regions are the proposals we need to extract; all 9 differently sized regions project, after pooling, onto this one 3*3 region. Different anchors thus yield many different proposals: for a feature map of size w*h we can find w*h*9 proposals (candidate regions), and for each we estimate its class and box coordinates.

Concretely, for Faster RCNN's anchors we consider 3 area scales (128, 256, 512) and, at each scale, regions of three aspect ratios (1:1, 1:2, 2:1). Each anchor thus yields 3*3 = 9 regions; multiplying by the number of anchors (the number of feature-map pixels) gives 9*w*h proposals, since every pixel becomes an anchor center. We classify the 9*w*h proposals and regress their box positions, and the resulting ROIs become the input to ROI Pooling for the subsequent location-refinement and fine-grained classification network.

Now let's implement each component of Faster RCNN in code, starting with the backbone. Following the original definition, we implement it with the 13-convolutional-layer VGG16; it could of course be replaced with a more complex ResNet or Inception network. The deeper the network, the better the image features it can extract, from simple color and texture features up to higher-level semantic features.

import tensorflow as tf
from tensorflow.keras import layers, models

class Vgg16:
    def __init__(self):
        self.first_fc_unit = 4096

    # VGG16 backbone: 13 conv layers plus 4 max-pooling layers (output stride 16)
    def image_to_head(self, inputs):
        x = layers.Conv2D(64, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(inputs)
        x = layers.Conv2D(64, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.MaxPool2D((2, 2), padding='SAME')(x)
        x = layers.Conv2D(128, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(128, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.MaxPool2D((2, 2), padding='SAME')(x)
        x = layers.Conv2D(256, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(256, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(256, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.MaxPool2D((2, 2), padding='SAME')(x)
        x = layers.Conv2D(512, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(512, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(512, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.MaxPool2D((2, 2), padding='SAME')(x)
        x = layers.Conv2D(512, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(512, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        x = layers.Conv2D(512, (3, 3), padding='SAME', activation='relu', kernel_regularizer='l2')(x)
        return x

    # two 4096-d FC layers intended for use after ROI pooling (not wired into build_graph below)
    def head_to_tail(self, inputs):
        x = layers.Flatten()(inputs)
        x = layers.Dense(4096, activation='relu', kernel_regularizer='l2')(x)
        x = layers.Dropout(0.1)(x)
        x = layers.Dense(4096, activation='relu', kernel_regularizer='l2')(x)
        x = layers.Dropout(0.1)(x)
        return x

    # demo graph: a classification Dense layer applied directly to the conv feature map
    def build_graph(self, input_shape, class_num):
        inputs = layers.Input(shape=input_shape)
        x = self.image_to_head(inputs=inputs)
        x = layers.Dense(class_num)(x)
        outputs = models.Model(inputs=inputs, outputs=x)
        return outputs


if __name__ == "__main__":
    vgg16 = Vgg16()
    vgg16_model = vgg16.build_graph((500, 500, 3), 10)
    vgg16_model.summary(line_length=100)

Output

Model: "model"
____________________________________________________________________________________________________
Layer (type)                                 Output Shape                            Param #        
====================================================================================================
input_1 (InputLayer)                         [(None, 500, 500, 3)]                   0              
____________________________________________________________________________________________________
conv2d (Conv2D)                              (None, 500, 500, 64)                    1792           
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                            (None, 500, 500, 64)                    36928          
____________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)                 (None, 250, 250, 64)                    0              
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                            (None, 250, 250, 128)                   73856          
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                            (None, 250, 250, 128)                   147584         
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)               (None, 125, 125, 128)                   0              
____________________________________________________________________________________________________
conv2d_4 (Conv2D)                            (None, 125, 125, 256)                   295168         
____________________________________________________________________________________________________
conv2d_5 (Conv2D)                            (None, 125, 125, 256)                   590080         
____________________________________________________________________________________________________
conv2d_6 (Conv2D)                            (None, 125, 125, 256)                   590080         
____________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)               (None, 63, 63, 256)                     0              
____________________________________________________________________________________________________
conv2d_7 (Conv2D)                            (None, 63, 63, 512)                     1180160        
____________________________________________________________________________________________________
conv2d_8 (Conv2D)                            (None, 63, 63, 512)                     2359808        
____________________________________________________________________________________________________
conv2d_9 (Conv2D)                            (None, 63, 63, 512)                     2359808        
____________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)               (None, 32, 32, 512)                     0              
____________________________________________________________________________________________________
conv2d_10 (Conv2D)                           (None, 32, 32, 512)                     2359808        
____________________________________________________________________________________________________
conv2d_11 (Conv2D)                           (None, 32, 32, 512)                     2359808        
____________________________________________________________________________________________________
conv2d_12 (Conv2D)                           (None, 32, 32, 512)                     2359808        
____________________________________________________________________________________________________
dense (Dense)                                (None, 32, 32, 10)                      5130           
====================================================================================================
Total params: 14,719,818
Trainable params: 14,719,818
Non-trainable params: 0

The second step, naturally, is to feed the feature map coming out of the backbone into the RPN. As mentioned when introducing SSD, we feed in not only the image data but also the ground-truth bounding-box data, i.e. the annotations, which have 5 dimensions: the class plus 4 coordinate values (xmin, ymin, xmax, ymax). See Tensorflow的圖像操作 and Tensorflow+SSD實戰 for details.

Here we take a 3-channel color image of size 500*600 as the example input; note, though, that the input image can be of any size.

import tensorflow as tf
from tensorflow.keras import layers, preprocessing, backend, models, optimizers, losses
from vgg16 import Vgg16
import numpy as np
from skimage.transform import resize
import cv2
from data.generate_voc_data import DataGenerator

if __name__ == "__main__":

    # im_input_shape = (None, None, 3)
    img = preprocessing.image.load_img("/Users/admin/Documents/6565.jpeg", color_mode='rgb')
    print(np.asarray(img).shape)
    img = preprocessing.image.img_to_array(img, dtype=np.uint8)
    img = resize(img, (500, 600, 3))
    r, g, b = cv2.split(img)
    img_new = cv2.merge((b, g, r))
    cv2.imshow('img', img_new)
    cv2.waitKey()
    img = tf.reshape(img, shape=(1, 500, 600, 3))
    img = tf.cast(img, dtype=tf.float32)
    im_input_shape = (500, 600, 3)
    gt_box_input_shape = (None, 5)
    batch_size = 1

    im_inputs = layers.Input(shape=im_input_shape, batch_size=batch_size)
    # gt_boxes = layers.Input(shape=gt_box_input_shape, batch_size=batch_size)
    gt_boxes = tf.constant(np.array([[21.2, 33.1, 461.2, 561.3, 1], [97.6, 88.9, 288.1, 367.9, 2]]), dtype=tf.float32)
    print(gt_boxes.shape)
    # gather the input image's height and width
    im_info = tf.cast([tf.shape(img)[1], tf.shape(img)[2]], dtype=tf.float32, name="im_info_cast")
    vgg16 = Vgg16()
    # run the image through the backbone to get the feature map, and read off its height and width
    feature_map = vgg16.image_to_head(img)
    feature_map_height = tf.shape(feature_map)[1]
    feature_map_width = tf.shape(feature_map)[2]
    print(feature_map)

Output

(500, 600, 3)
(2, 5)
tf.Tensor(
[[[[0.0000000e+00 3.0219162e-04 1.9649921e-04 ... 9.5941679e-05
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 3.4777069e-04 1.4632757e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 4.5308818e-05 2.7419327e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]
   ...
   [0.0000000e+00 2.3457017e-04 3.4940938e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 6.3231288e-05 3.5290085e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 0.0000000e+00 3.0360566e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]]

  [[0.0000000e+00 6.1431236e-04 2.8230989e-04 ... 3.3268472e-05
    0.0000000e+00 5.3799915e-05]
   [0.0000000e+00 6.0494588e-04 2.0055997e-04 ... 0.0000000e+00
    0.0000000e+00 1.2344688e-04]
   [0.0000000e+00 3.6424759e-04 3.6444786e-04 ... 0.0000000e+00
    0.0000000e+00 2.4745914e-05]
   ...
   [0.0000000e+00 4.5090867e-04 5.4513843e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 9.5991025e-05 5.6906225e-04 ... 7.8708392e-05
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 0.0000000e+00 4.2226817e-04 ... 2.8439818e-04
    0.0000000e+00 0.0000000e+00]]

  [[0.0000000e+00 7.8332843e-04 4.0208295e-04 ... 6.9430265e-05
    0.0000000e+00 2.0329645e-04]
   [0.0000000e+00 7.0123601e-04 3.6532484e-04 ... 5.6255354e-05
    0.0000000e+00 2.3254988e-04]
   [0.0000000e+00 3.2104889e-04 4.3427886e-04 ... 1.4044531e-04
    0.0000000e+00 7.9119789e-05]
   ...
   [0.0000000e+00 4.7177816e-04 6.6962885e-04 ... 0.0000000e+00
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 0.0000000e+00 6.2687905e-04 ... 5.2743500e-05
    0.0000000e+00 0.0000000e+00]
   [0.0000000e+00 0.0000000e+00 5.7314354e-04 ... 2.3048952e-04
    0.0000000e+00 0.0000000e+00]]

  ...

  [[0.0000000e+00 4.3245990e-04 1.5730722e-04 ... 2.4690776e-04
    0.0000000e+00 6.2117389e-05]
   [0.0000000e+00 1.4534194e-04 1.4233941e-04 ... 5.3184008e-04
    0.0000000e+00 0.0000000e+00]
   [6.4062777e-05 0.0000000e+00 1.2227474e-04 ... 3.8292457e-04
    0.0000000e+00 0.0000000e+00]
   ...
   [1.9968928e-04 7.8673183e-06 3.2361579e-04 ... 7.7268225e-04
    0.0000000e+00 0.0000000e+00]
   [2.5986420e-04 0.0000000e+00 4.0086202e-04 ... 7.3863228e-04
    0.0000000e+00 0.0000000e+00]
   [1.5432760e-04 0.0000000e+00 4.4641300e-04 ... 9.4485411e-05
    0.0000000e+00 0.0000000e+00]]

  [[0.0000000e+00 3.8628629e-04 1.4483498e-04 ... 1.1889226e-04
    0.0000000e+00 0.0000000e+00]
   [6.0426763e-05 1.0252794e-04 1.2850111e-04 ... 3.2820049e-04
    0.0000000e+00 0.0000000e+00]
   [7.3168143e-05 7.1833842e-05 2.0838881e-04 ... 3.2303098e-04
    0.0000000e+00 0.0000000e+00]
   ...
   [3.9218267e-04 8.4534899e-05 3.4109424e-04 ... 7.3586428e-04
    0.0000000e+00 0.0000000e+00]
   [3.9436651e-04 0.0000000e+00 3.9049587e-04 ... 5.8985659e-04
    0.0000000e+00 0.0000000e+00]
   [3.5940378e-04 0.0000000e+00 3.1699456e-04 ... 9.4989351e-05
    0.0000000e+00 0.0000000e+00]]

  [[2.1385179e-04 1.9172599e-04 0.0000000e+00 ... 3.4191049e-05
    0.0000000e+00 0.0000000e+00]
   [2.9973016e-04 2.9921081e-05 0.0000000e+00 ... 2.2193138e-04
    0.0000000e+00 0.0000000e+00]
   [4.2951890e-04 0.0000000e+00 0.0000000e+00 ... 8.5445106e-05
    0.0000000e+00 0.0000000e+00]
   ...
   [4.4596923e-04 0.0000000e+00 0.0000000e+00 ... 2.9606459e-04
    0.0000000e+00 0.0000000e+00]
   [3.5400479e-04 0.0000000e+00 0.0000000e+00 ... 2.7514625e-04
    0.0000000e+00 0.0000000e+00]
   [1.0696228e-04 0.0000000e+00 1.2458675e-04 ... 2.2076641e-05
    0.0000000e+00 0.0000000e+00]]]], shape=(1, 32, 38, 512), dtype=float32)

The image we load is a real photo of size 500*600*3, and the feature map produced by the backbone has shape (1, 32, 38, 512). (With SAME padding each pooling layer rounds up, so the sides come out as 500 → 250 → 125 → 63 → 32 and 600 → 300 → 150 → 75 → 38, rather than exactly 500/16 and 600/16.) Next comes the RPN's anchor mechanism. We treat the whole anchor mechanism as one layer of the network, whose inputs are the feature map's height and width. As we saw, after the 4 max-pooling downsamplings the feature map is 1/16 the size of the original image, and we now want to map every feature-map point back onto the original image. First, let's look at the original-image coordinates corresponding to the feature map's width and height.

# positions of the feature map's width and height on the original image
feat_stride = 16
shift_x = tf.multiply(tf.range(feature_map_width, name='range_shift_x'), feat_stride)
shift_y = tf.multiply(tf.range(feature_map_height, name='range_shift_y'), feat_stride)
print(shift_x)
print(shift_y)

Output

tf.Tensor(
[  0  16  32  48  64  80  96 112 128 144 160 176 192 208 224 240 256 272
 288 304 320 336 352 368 384 400 416 432 448 464 480 496 512 528 544 560
 576 592], shape=(38,), dtype=int32)
tf.Tensor(
[  0  16  32  48  64  80  96 112 128 144 160 176 192 208 224 240 256 272
 288 304 320 336 352 368 384 400 416 432 448 464 480 496], shape=(32,), dtype=int32)

This only maps the width and the height; we need the original-image coordinates of every point on the feature map. We therefore take the mapped values above and process them further.

# build the coordinate grid
shift_x, shift_y = tf.meshgrid(shift_x, shift_y, name="meshgrid_x_y")
print('shift_x', shift_x)
print('shift_y', shift_y)
# flatten
sx = tf.reshape(shift_x, shape=(-1,), name='reshape_sx')
sy = tf.reshape(shift_y, shape=(-1,), name='reshape_sy')
print('sx', sx)
print('sy', sy)
# stack
xyxy = tf.stack([sx, sy, sx, sy], name='stack_xyxy')
print('xyxy', xyxy)
# transpose
shifts = tf.transpose(xyxy, name='transpose_shifts')
print('shifts', shifts)

Output

shift_x tf.Tensor(
[[  0  16  32 ... 560 576 592]
 [  0  16  32 ... 560 576 592]
 [  0  16  32 ... 560 576 592]
 ...
 [  0  16  32 ... 560 576 592]
 [  0  16  32 ... 560 576 592]
 [  0  16  32 ... 560 576 592]], shape=(32, 38), dtype=int32)
shift_y tf.Tensor(
[[  0   0   0 ...   0   0   0]
 [ 16  16  16 ...  16  16  16]
 [ 32  32  32 ...  32  32  32]
 ...
 [464 464 464 ... 464 464 464]
 [480 480 480 ... 480 480 480]
 [496 496 496 ... 496 496 496]], shape=(32, 38), dtype=int32)
sx tf.Tensor([  0  16  32 ... 560 576 592], shape=(1216,), dtype=int32)
sy tf.Tensor([  0   0   0 ... 496 496 496], shape=(1216,), dtype=int32)
xyxy tf.Tensor(
[[  0  16  32 ... 560 576 592]
 [  0   0   0 ... 496 496 496]
 [  0  16  32 ... 560 576 592]
 [  0   0   0 ... 496 496 496]], shape=(4, 1216), dtype=int32)
shifts tf.Tensor(
[[  0   0   0   0]
 [ 16   0  16   0]
 [ 32   0  32   0]
 ...
 [560 496 560 496]
 [576 496 576 496]
 [592 496 592 496]], shape=(1216, 4), dtype=int32)

# total number of pixels in the feature map
K = tf.multiply(feature_map_width, feature_map_height, name='multi_w_h')
print(K)

Output

tf.Tensor(1216, shape=(), dtype=int32)

The result shows that the feature map has 1216 pixels in total.

shifts_reshape = tf.reshape(shifts, shape=[1, K, 4], name='shifts_reshape')
print('shifts_reshape', shifts_reshape)
# swap dimensions 0 and 1 to get the 1216 center points the feature map maps to on the original image
shifts = tf.transpose(shifts_reshape, perm=(1, 0, 2))
print('shifts', shifts)

Output

shifts_reshape tf.Tensor(
[[[  0   0   0   0]
  [ 16   0  16   0]
  [ 32   0  32   0]
  ...
  [560 496 560 496]
  [576 496 576 496]
  [592 496 592 496]]], shape=(1, 1216, 4), dtype=int32)
shifts tf.Tensor(
[[[  0   0   0   0]]

 [[ 16   0  16   0]]

 [[ 32   0  32   0]]

 ...

 [[560 496 560 496]]

 [[576 496 576 496]]

 [[592 496 592 496]]], shape=(1216, 1, 4), dtype=int32)

With these mapped points we can generate the bounding boxes. Here we set up a function that generates the candidate-region boxes.

# aspect ratios 1:2, 1:1, 2:1
anchor_ratios = (0.5, 1, 2)
# scale constants 2^3, 2^4, 2^5
anchor_scales = (8, 16, 32)
anchors = generate_anchors(ratios=np.array(anchor_ratios), scales=np.array(anchor_scales))

Now let's look at the generate_anchors function.

def generate_anchors(base_size=16,
                     ratios=[0.5, 1., 2.],
                     scales=2 ** np.arange(3, 6)):
    '''
    Generate anchors at multiple scales and aspect ratios
    '''
    # generate a base 16*16 bounding box: each of the feature map's 32*38=1216 points
    # corresponds to a 16*16 region of the 500*600=300000 original image
    base_anchor = np.array([1, 1, base_size, base_size]) - 1
    # enumerate the aspect ratios
    ratio_anchors = _ratio_enum(base_anchor, ratios)
    # enumerate the scales
    anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
                         for i in range(ratio_anchors.shape[0])])
    return anchors


def _ratio_enum(anchor, ratios):
    """
    Enumerate the three aspect ratios (1:2, 1:1, 2:1) for one anchor
    """
    # return the width, height and center coordinates
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    # here w=16, h=16, size=16*16=256
    size = w * h
    # 256/ratios[0.5,1,2]=[512,256,128]
    # this matches the three area scales (128, 256, 512)
    size_ratios = size / ratios
    # ws:[23 16 11]
    ws = np.round(np.sqrt(size_ratios))
    # hs:[12 16 22]
    hs = np.round(ws * ratios)
    # given vectors of widths and heights, output each prediction window
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors

def _whctrs(anchor):
    """
    Return the width, height and center (x, y) coordinates of an anchor window
    """
    # anchor stores the top-left and bottom-right coordinates of the window
    w = anchor[2] - anchor[0] + 1
    h = anchor[3] - anchor[1] + 1
    # center coordinates of the anchor
    x_ctr = anchor[0] + 0.5 * (w - 1)
    y_ctr = anchor[1] + 0.5 * (h - 1)
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """
    Given vectors of widths and heights around a shared center, output the anchors
    (prediction windows); the outputs share the same area and differ only in aspect ratio
    """
    # ws: [[23], [16], [11]]
    ws = ws[:, np.newaxis]
    # hs: [[12], [16], [22]]
    hs = hs[:, np.newaxis]
    # derive the top-left and bottom-right corners from the center coordinates
    anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
                         y_ctr - 0.5 * (hs - 1),
                         x_ctr + 0.5 * (ws - 1),
                         y_ctr + 0.5 * (hs - 1)))
    return anchors

def _scale_enum(anchor, scales):
    """
    Enumerate the three scales (128*128, 256*256, 512*512) for one anchor
    """
    # return the width, height and center coordinates
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    # [128 256 512]
    ws = w * scales
    # [128 256 512]
    hs = h * scales
    # given vectors of widths and heights, output each prediction window
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors

Print the generated prediction boxes:

print('anchors', anchors)
print(anchors.shape)

Output

anchors [[ -84.  -40.   99.   55.]
 [-176.  -88.  191.  103.]
 [-360. -184.  375.  199.]
 [ -56.  -56.   71.   71.]
 [-120. -120.  135.  135.]
 [-248. -248.  263.  263.]
 [ -36.  -80.   51.   95.]
 [ -80. -168.   95.  183.]
 [-168. -344.  183.  359.]]
(9, 4)

# A = 9: the 9 kinds of prediction boxes per anchor
A = anchors.shape[0]
anchor_constant = tf.reshape(anchors, (1, A, 4), name='anchor_constant')
anchor_constant = tf.cast(anchor_constant, dtype=tf.int32, name='anchor_constant_cast')
print('anchor_constant', anchor_constant)
# 1216*9=10944: each center point has 9 prediction boxes,
# and there are 1216 center points, so 10944 prediction boxes in total
length = tf.multiply(K, A, name='length')
print('length', length)
# convert each of the 1216 feature-map center points on the original image
# into the coordinates of its 9 prediction boxes: 10944 box coordinates in total
anchors_add_shifts = tf.add(anchor_constant, shifts, name='anchors_add_shifts')
print('anchors_add_shifts', anchors_add_shifts)

Output

anchor_constant tf.Tensor(
[[[ -84  -40   99   55]
  [-176  -88  191  103]
  [-360 -184  375  199]
  [ -56  -56   71   71]
  [-120 -120  135  135]
  [-248 -248  263  263]
  [ -36  -80   51   95]
  [ -80 -168   95  183]
  [-168 -344  183  359]]], shape=(1, 9, 4), dtype=int32)
length tf.Tensor(10944, shape=(), dtype=int32)
anchors_add_shifts tf.Tensor(
[[[ -84  -40   99   55]
  [-176  -88  191  103]
  [-360 -184  375  199]
  ...
  [ -36  -80   51   95]
  [ -80 -168   95  183]
  [-168 -344  183  359]]

 [[ -68  -40  115   55]
  [-160  -88  207  103]
  [-344 -184  391  199]
  ...
  [ -20  -80   67   95]
  [ -64 -168  111  183]
  [-152 -344  199  359]]

 [[ -52  -40  131   55]
  [-144  -88  223  103]
  [-328 -184  407  199]
  ...
  [  -4  -80   83   95]
  [ -48 -168  127  183]
  [-136 -344  215  359]]

 ...

 [[ 476  456  659  551]
  [ 384  408  751  599]
  [ 200  312  935  695]
  ...
  [ 524  416  611  591]
  [ 480  328  655  679]
  [ 392  152  743  855]]

 [[ 492  456  675  551]
  [ 400  408  767  599]
  [ 216  312  951  695]
  ...
  [ 540  416  627  591]
  [ 496  328  671  679]
  [ 408  152  759  855]]

 [[ 508  456  691  551]
  [ 416  408  783  599]
  [ 232  312  967  695]
  ...
  [ 556  416  643  591]
  [ 512  328  687  679]
  [ 424  152  775  855]]], shape=(1216, 9, 4), dtype=int32)

# flatten the first and second dimensions, keeping only the box count and coordinate values
anchors_tf = tf.reshape(anchors_add_shifts, shape=(length, 4), name='anchors_tf')
anchors_tf_cast = tf.cast(anchors_tf, dtype=tf.float32, name='anchors_tf_cast')
print('anchors_tf_cast', anchors_tf_cast)

Output

anchors_tf_cast tf.Tensor(
[[ -84.  -40.   99.   55.]
 [-176.  -88.  191.  103.]
 [-360. -184.  375.  199.]
 ...
 [ 556.  416.  643.  591.]
 [ 512.  328.  687.  679.]
 [ 424.  152.  775.  855.]], shape=(10944, 4), dtype=float32)

Now we start building the whole RPN network, which performs the coarse localization and coarse classification.

is_training = True
if is_training:
    # RPN: first bbox regression and class prediction
    # rois, roi_scores, labels, anchor_targets, proposal_targets, predictions = \
    region_proposal_network(conv_net=feature_map,
                            anchors=anchors_tf_cast,
                            gt_boxes=gt_boxes,
                            im_info=im_info,
                            is_training=is_training)

Before building the whole RPN network, it helps to take a look at this figure.

The figure above is effectively the definition of the region_proposal_network function.

def region_proposal_network(conv_net, anchors, gt_boxes, im_info, is_training):
    """
    RPN network: class prediction and box prediction on the feature map
    produced by the preceding convolutional network
    """
    anchor_targets = {}
    predictions = {}
    proposal_targets = {}

    # shared conv layer that further extracts features; conv_net is the feature map output by vgg16
    rpn = layers.Conv2D(512, (3, 3), padding='SAME', kernel_regularizer='l2')(conv_net)
    # class prediction (foreground or not); 9 = the 9 kinds of prediction boxes per anchor center
    # i.e. a binary classification of the 9 anchor boxes at every pixel
    rpn_cls_score = layers.Conv2D(9 * 2, (1, 1), padding='VALID', kernel_regularizer='l2')(rpn)
    print('rpn_cls_score', rpn_cls_score)
    # reshape into 2 channels
    rpn_cls_score_reshape = tf.reshape(rpn_cls_score, (-1, 2))
    print('rpn_cls_score_reshape', rpn_cls_score_reshape)
    # softmax normalization over the class channel
    rpn_cls_prob_reshape = layers.Softmax()(rpn_cls_score_reshape)
    print('rpn_cls_prob_reshape', rpn_cls_prob_reshape)
    # predicted class
    rpn_cls_pred = tf.argmax(tf.reshape(rpn_cls_score_reshape, [-1, 2]), axis=1, name="rpn_cls_pred")
    print('rpn_cls_pred', rpn_cls_pred)
    # reshape back to [1, h, w, 9*2]
    rpn_cls_prob = tf.reshape(rpn_cls_prob_reshape, tf.shape(rpn_cls_score))
    print('rpn_cls_prob', rpn_cls_prob)

Output

rpn_cls_score tf.Tensor(
[[[[ 1.28591113e-04 -1.19735530e-04  8.66432092e-05 ...  2.88864394e-04
    -4.07604326e-04 -1.69272156e-04]
   [ 2.55654915e-04 -8.35588071e-05 -1.04449326e-04 ...  2.39458124e-04
    -5.06101234e-04 -5.30362653e-04]
   [ 3.16421676e-04 -1.26520536e-04  9.92472123e-05 ...  3.97705706e-04
    -3.85666150e-04 -6.93914888e-04]
   ...
   [ 1.20688761e-04 -1.99504429e-04  9.57472657e-05 ...  2.42701033e-04
    -2.55647494e-04 -6.55716402e-04]
   [-1.03395330e-04 -1.21175937e-04  1.56538765e-04 ...  1.54868467e-04
    -1.11868736e-04 -4.43565950e-04]
   [-1.39908167e-04 -1.05267529e-04  1.13623471e-04 ... -1.56470633e-05
    -4.12553418e-05 -4.12014575e-04]]

  [[ 1.49776650e-04 -9.10804374e-05  2.88046431e-04 ...  3.02948640e-04
    -3.63217317e-04 -1.90231280e-04]
   [ 3.36535100e-04 -7.66881567e-06  7.98430701e-05 ... -7.11385364e-05
    -5.36559557e-04 -6.61224243e-04]
   [ 3.84626997e-04 -3.68737819e-05  2.20809336e-04 ...  7.86082819e-05
    -4.07071697e-04 -9.39455058e-04]
   ...
   [ 2.02674215e-04 -1.89493439e-04  3.11142867e-06 ...  2.52468861e-04
    -3.70154681e-04 -7.24348938e-04]
   [-1.27771782e-04 -4.68083017e-05  1.03595143e-04 ...  1.01013677e-04
    -1.95325614e-04 -5.66674047e-04]
   [-2.48806085e-04  4.83075928e-06 -1.46968305e-05 ... -8.54994578e-05
    -1.23633450e-04 -8.11211416e-04]]

  [[ 1.08757929e-04  1.71142165e-06  2.16108936e-04 ...  2.87988601e-04
    -2.88046314e-04 -2.15804204e-04]
   [ 5.20187314e-04  4.18786367e-05  7.55323927e-05 ... -8.08777622e-05
    -5.44671610e-04 -5.80734690e-04]
   [ 4.77890950e-04 -4.85706405e-05  3.39652121e-04 ...  5.67078387e-05
    -3.85383406e-04 -9.47830267e-04]
   ...
   [ 2.25324984e-04 -4.68418963e-04  1.25313731e-04 ...  3.74071707e-04
    -3.56081349e-04 -8.71837721e-04]
   [ 8.46904877e-05 -2.14670261e-04  1.91344385e-04 ...  2.19033405e-04
    -1.13443792e-04 -8.45648930e-04]
   [-2.25353113e-04 -1.58908137e-04 -1.55436093e-04 ... -1.15734685e-04
    -1.64380966e-04 -1.23383116e-03]]

  ...

  [[ 2.74988357e-04  1.35776325e-04  1.90707302e-04 ...  7.08626976e-05
    -1.85924961e-04  2.51720747e-04]
   [ 5.98442683e-04  1.32858098e-04  2.35301297e-04 ... -3.01424996e-04
    -3.14001692e-04 -2.85727932e-04]
   [ 4.83462121e-04 -8.81854794e-05  6.13140874e-04 ... -2.69158510e-04
    -3.77402728e-04 -4.99104965e-04]
   ...
   [ 6.20365725e-04 -4.86295554e-04  8.08023557e-04 ... -4.15926916e-04
    -6.74676732e-04 -5.71158133e-04]
   [ 4.36893781e-04 -4.87219484e-04  8.22506379e-04 ... -4.02537291e-04
    -4.83523821e-04 -7.54523149e-04]
   [-9.37685327e-05 -3.76466109e-04  5.01742063e-04 ... -3.61098646e-04
    -6.76137744e-04 -9.69226006e-04]]

  [[ 2.57291249e-04  1.71593216e-04  1.39481446e-04 ... -6.48936984e-05
    -9.34292475e-05  1.93091910e-04]
   [ 4.63203934e-04  1.43579920e-04  1.70695566e-04 ... -2.55600840e-04
    -9.44221101e-05 -2.44561932e-04]
   [ 4.48389910e-04 -7.37725713e-05  3.07883398e-04 ... -3.30948038e-04
    -8.29529599e-05 -5.65320021e-04]
   ...
   [ 6.21744723e-04 -3.55675351e-04  5.75212180e-04 ... -2.07936682e-04
    -5.43940230e-04 -7.03393365e-04]
   [ 4.22565761e-04 -3.90416360e-04  5.50721132e-04 ... -2.39470261e-04
    -2.63691705e-04 -7.76427216e-04]
   [-1.02703561e-05 -1.18101110e-04  2.71086057e-04 ... -2.65686569e-04
    -4.73425083e-04 -8.79082421e-04]]

  [[ 6.92889298e-05  1.57802220e-04  1.64041339e-05 ... -1.62374170e-04
    -2.70871533e-05  3.24778375e-04]
   [ 1.32698027e-04 -9.20765306e-05  5.00971364e-05 ... -3.04940884e-04
    -3.09051029e-05  1.08687324e-04]
   [ 6.81239108e-05 -1.21253353e-04  2.34147767e-04 ... -4.10851702e-04
    -5.17751105e-05  1.76154426e-05]
   ...
   [ 8.64961403e-05 -2.83809932e-04  6.33362331e-04 ... -5.19489229e-04
    -5.11460064e-04 -1.83705750e-04]
   [-9.00256127e-05 -2.46202835e-04  5.32777165e-04 ... -4.89840168e-04
    -3.79110454e-04 -3.53413867e-04]
   [-1.18138312e-04 -1.33659603e-04  3.29175964e-04 ... -3.05922411e-04
    -3.96786316e-04 -4.52264474e-04]]]], shape=(1, 32, 38, 18), dtype=float32)
rpn_cls_score_reshape tf.Tensor(
[[ 1.2859111e-04 -1.1973553e-04]
 [ 8.6643209e-05  1.6108275e-04]
 [-2.1313256e-04  3.2927433e-05]
 ...
 [ 5.8368019e-05  2.1016470e-04]
 [ 1.8138027e-04 -3.0592241e-04]
 [-3.9678632e-04 -4.5226447e-04]], shape=(10944, 2), dtype=float32)
rpn_cls_prob_reshape tf.Tensor(
[[0.50006205 0.4999379 ]
 [0.49998137 0.5000186 ]
 [0.4999385  0.5000615 ]
 ...
 [0.49996206 0.50003797]
 [0.50012183 0.49987817]
 [0.5000139  0.49998614]], shape=(10944, 2), dtype=float32)
rpn_cls_pred tf.Tensor([0 1 1 ... 1 0 0], shape=(10944,), dtype=int64)
rpn_cls_prob tf.Tensor(
[[[[0.50006205 0.4999379  0.49998137 ... 0.50003326 0.49994043
    0.5000596 ]
   [0.5000848  0.4999152  0.49990672 ... 0.5000105  0.5000061
    0.49999395]
   [0.5001107  0.49988922 0.49997547 ... 0.50008696 0.50007707
    0.49992293]
   ...
   [0.50008005 0.49991995 0.49991325 ... 0.5000954  0.5001
    0.49989998]
   [0.5000045  0.4999956  0.49995524 ... 0.5000671  0.5000829
    0.49991706]
   [0.49999133 0.50000864 0.5000078  ... 0.5000597  0.5000927
    0.4999073 ]]

  [[0.5000602  0.49993977 0.50006527 ... 0.5000326  0.4999568
    0.5000433 ]
   [0.50008607 0.49991396 0.50006646 ... 0.49988762 0.5000312
    0.49996883]
   [0.5001054  0.49989462 0.5001057  ... 0.49998683 0.5001331
    0.4998669 ]
   ...
   [0.50009805 0.49990198 0.4999725  ... 0.50010264 0.5000885
    0.49991143]
   [0.4999798  0.50002027 0.50003725 ... 0.5000665  0.5000928
    0.49990714]
   [0.4999366  0.5000634  0.49998552 ... 0.4999955  0.5001719
    0.4998281 ]]

  [[0.50002676 0.49997324 0.500081   ... 0.49997538 0.49998194
    0.50001806]
   [0.50011957 0.4998804  0.5001116  ... 0.4998631  0.500009
    0.49999097]
   [0.50013167 0.49986842 0.50017715 ... 0.49997088 0.5001406
    0.4998594 ]
   ...
   [0.50017345 0.49982658 0.5000165  ... 0.5001252  0.500129
    0.4998711 ]
   [0.5000748  0.49992514 0.50006133 ... 0.5000737  0.50018305
    0.49981695]
   [0.4999834  0.5000166  0.49997988 ... 0.49997655 0.5002673
    0.4997326 ]]

  ...

  [[0.5000348  0.4999652  0.5001331  ... 0.49991515 0.49989057
    0.5001094 ]
   [0.50011635 0.49988356 0.5001697  ... 0.4997782  0.49999297
    0.5000071 ]
   [0.50014293 0.4998571  0.50028276 ... 0.49974218 0.50003046
    0.4999696 ]
   ...
   [0.5002767  0.49972335 0.5003303  ... 0.49972644 0.4999741
    0.50002587]
   [0.500231   0.49976897 0.50034046 ... 0.4997262  0.5000678
    0.49993226]
   [0.5000707  0.49992934 0.5001798  ... 0.49982387 0.50007325
    0.49992672]]

  [[0.50002146 0.4999786  0.50008607 ... 0.49993128 0.49992839
    0.50007164]
   [0.5000799  0.49992007 0.5001129  ... 0.4998723  0.50003755
    0.49996248]
   [0.50013053 0.49986947 0.5001542  ... 0.4998455  0.5001206
    0.4998794 ]
   ...
   [0.5002444  0.49975568 0.5002672  ... 0.49983895 0.5000399
    0.49996015]
   [0.50020325 0.49979675 0.50024486 ... 0.49983242 0.5001282
    0.49987185]
   [0.50002694 0.49997303 0.5001265  ... 0.49985176 0.5001014
    0.49989855]]

  [[0.49997786 0.5000221  0.50007534 ... 0.49990255 0.49991205
    0.500088  ]
   [0.5000562  0.49994382 0.50011855 ... 0.49985442 0.49996513
    0.5000349 ]
   [0.5000473  0.49995264 0.50017565 ... 0.49981332 0.49998266
    0.50001734]
   ...
   [0.50009257 0.4999074  0.50032103 ... 0.4997933  0.49991807
    0.50008196]
   [0.50003904 0.49996096 0.5002959  ... 0.49980077 0.4999936
    0.50000644]
   [0.5000039  0.49999613 0.5001537  ... 0.49987817 0.5000139
    0.49998614]]]], shape=(1, 32, 38, 18), dtype=float32)

The region_proposal_network function continues:

# box prediction; note that what is predicted is not left, bottom, right, top but the offsets between the anchors and the ground-truth boxes
rpn_bbox_pred = layers.Conv2D(9 * 4, (1, 1), padding='VALID', kernel_regularizer='l2')(rpn)
print('rpn_bbox_pred', rpn_bbox_pred)

if is_training:
    # compare the predicted boxes with the anchors; after non-maximum suppression, output the final target boxes [[0, x1, y1, x2, y2], ...] and their scores
    scores = tf.reshape(rpn_cls_prob, (-1, 2))[:, 1]
    rpn_bbox_pred = tf.reshape(rpn_bbox_pred, shape=(-1, 4))
    print('scores', scores)
    print('rpn_bbox_pred', rpn_bbox_pred)
    # obtain the proposals (candidate regions) from the anchors and the offsets
    proposals = bbox_transform_inv_tf(anchors, rpn_bbox_pred)
    print('proposals', proposals)

Output

rpn_bbox_pred tf.Tensor(
[[[[ 3.94443196e-05  6.92607427e-05  4.00393765e-05 ...  1.22013189e-04
    -2.10186525e-04  1.58461131e-04]
   [-7.33447159e-05 -1.00469137e-04  1.53445682e-04 ...  2.60028843e-04
    -3.73837072e-04  6.37015910e-05]
   [-1.56758804e-04 -8.66888222e-05 -4.62429452e-05 ...  2.84914131e-04
    -3.68396693e-04 -6.33426971e-05]
   ...
   [-2.70735589e-04  8.23117443e-05 -1.75508481e-04 ...  2.26667558e-04
    -3.99926968e-04 -2.47621385e-04]
   [-3.75439879e-04  1.01676538e-04 -2.11953957e-04 ...  2.15460568e-05
    -2.11980019e-04 -3.59909434e-04]
   [-2.55937048e-04  1.28262309e-05 -9.20578968e-05 ...  9.42961924e-06
     9.30242386e-05 -1.36353803e-04]]

  [[ 3.23344284e-05  1.26386323e-04 -2.43344621e-05 ...  2.27981189e-04
    -2.63415859e-04 -1.10090186e-05]
   [-5.38501481e-05 -1.08666136e-04  1.72905216e-04 ...  2.16895831e-04
    -6.81167119e-04 -5.84048321e-05]
   [-8.37506741e-05 -1.97177404e-04 -4.39312134e-05 ...  2.23244628e-04
    -6.05323992e-04 -2.43602190e-04]
   ...
   [-1.72304979e-04  3.09660027e-05 -5.51066260e-05 ...  1.10222783e-04
    -7.49949482e-04 -2.74291961e-04]
   [-2.78306921e-04  1.05404150e-04 -1.39134208e-04 ... -2.18430447e-04
    -4.07704647e-04 -3.52783041e-04]
   [-2.27868732e-04 -4.20076794e-05  5.77527462e-05 ... -1.33788591e-04
    -2.66521383e-05 -1.57433329e-04]]

  [[ 1.26381114e-04  2.36857828e-04 -4.49343279e-05 ...  2.70184013e-04
    -4.55210684e-04  7.47668819e-05]
   [ 7.66618323e-05  4.51652741e-05  2.26735152e-04 ...  2.06948214e-04
    -8.77024489e-04  3.30306793e-05]
   [-2.87178700e-05 -9.91490742e-05 -3.63895524e-05 ...  1.31564200e-04
    -6.95443421e-04 -1.54453504e-04]
   ...
   [-8.94593977e-05  2.87120696e-04 -7.03832775e-05 ...  1.17514035e-04
    -7.71688530e-04 -1.51907501e-04]
   [-2.05144679e-04  2.40545342e-04 -8.76502163e-05 ... -2.32767721e-04
    -3.95883369e-04 -1.98556343e-04]
   [-2.18652611e-04  1.10108122e-05  3.09899478e-05 ... -2.36096195e-04
    -1.22801648e-04 -1.07635897e-04]]

  ...

  [[ 1.58334660e-04  2.44860712e-04  7.73576830e-05 ...  1.13771785e-04
    -3.13729455e-04 -2.07834819e-04]
   [ 1.11568668e-04  1.59548974e-04  2.74635444e-04 ... -3.39365870e-05
    -4.80627408e-04 -2.21940863e-04]
   [-3.88712979e-05  1.48852137e-04  2.04777578e-04 ... -5.28344499e-05
    -3.42429965e-04 -2.93135992e-04]
   ...
   [-1.45276601e-04  1.42861991e-05  2.75106111e-04 ... -3.76468161e-05
    -2.59552762e-04 -2.05685734e-04]
   [-1.90249673e-04 -1.89188344e-04 -6.06206595e-06 ... -4.22182493e-05
    -5.07665100e-05  1.73006410e-05]
   [-1.66312515e-04 -1.57681716e-04  8.40802095e-05 ... -2.80631648e-04
     4.64846089e-05 -2.69745069e-05]]

  [[ 2.52315018e-04  1.37318755e-04  7.34694186e-06 ...  6.88106156e-05
    -2.59639084e-04 -5.23302515e-05]
   [ 1.09177468e-04  1.19801589e-04  1.20560479e-04 ... -3.13964119e-05
    -4.32244909e-04 -5.61552661e-05]
   [ 5.39667781e-05  9.82593701e-05  1.16698255e-04 ... -5.03267838e-05
    -2.68034113e-04 -2.03717238e-04]
   ...
   [-1.23629114e-04  2.08134879e-05  1.45863407e-04 ... -5.22632035e-05
    -2.75032769e-04 -2.37685512e-04]
   [-1.31681081e-04 -8.72807868e-05  3.63341242e-05 ...  3.34264623e-05
    -9.40441096e-05 -3.65532142e-05]
   [-8.05665040e-05 -1.35599810e-04  1.54671245e-04 ... -3.12842079e-04
     3.94195813e-05 -9.37813238e-05]]

  [[ 8.34726743e-05  2.34369931e-04 -6.86994463e-06 ...  9.92854621e-05
    -1.32194342e-04 -4.14272654e-05]
   [ 2.12691193e-05  2.02745723e-04  1.44213263e-05 ...  1.15128481e-04
    -2.03963034e-04 -2.29535217e-05]
   [-3.14584759e-05  2.34757754e-04 -5.31294791e-06 ...  6.15758472e-05
    -6.53850293e-05 -8.11161080e-05]
   ...
   [-1.62113894e-04  1.31983543e-04 -1.40095581e-05 ... -6.84586121e-05
    -3.55633529e-06 -2.58197979e-05]
   [-8.86888083e-05  4.30948567e-05  1.00263242e-05 ... -3.75453237e-06
     4.66970232e-05 -1.41085475e-05]
   [-5.73820544e-05 -3.65810120e-05  1.05833111e-04 ... -2.06670520e-04
     2.01888906e-05  6.76948039e-05]]]], shape=(1, 32, 38, 36), dtype=float32)
scores tf.Tensor([0.5000207  0.5000433  0.4999826  ... 0.5000217  0.50003844 0.49997947], shape=(10944,), dtype=float32)
rpn_bbox_pred tf.Tensor(
[[ 3.9444320e-05  6.9260743e-05  4.0039376e-05  1.3333066e-04]
 [-2.8608961e-04 -9.5688820e-06 -5.2532734e-05  1.4907753e-05]
 [-9.5204181e-05 -1.2283269e-04  3.5972198e-04  1.9759413e-05]
 ...
 [-1.5519327e-05  1.1782763e-04  1.6487727e-04  9.9818295e-05]
 [ 2.1546082e-04 -4.3670101e-05  9.9615791e-05  3.3442757e-05]
 [ 9.5570911e-05 -2.0667052e-04  2.0188891e-05  6.7694804e-05]], shape=(10944, 4), dtype=float32)
proposals tf.Tensor(
[[ -83.99643   -39.99975   100.01094    56.013046]
 [-176.09563   -88.00327   191.88506   103.999596]
 [-360.2025   -184.05096   376.06238   199.95663 ]
 ...
 [ 555.9914    416.01196   644.0059    592.02954 ]
 [ 512.0291    327.97873   688.0467    679.9905  ]
 [ 424.0301    151.83066   776.0372    855.8783  ]], shape=(10944, 4), dtype=float32)

Here let's look at the bbox_transform_inv_tf function.

def bbox_transform_inv_tf(boxes, deltas):
    '''
    Apply the predicted offsets to the anchors to obtain refined anchors
    '''
    boxes = tf.cast(boxes, deltas.dtype)
    # widths, heights and center points of all anchor prediction boxes
    widths = tf.subtract(boxes[:, 2], boxes[:, 0]) + 1.0
    heights = tf.subtract(boxes[:, 3], boxes[:, 1]) + 1.0
    ctr_x = tf.add(boxes[:, 0], widths * 0.5)
    ctr_y = tf.add(boxes[:, 1], heights * 0.5)
    # box offsets predicted by the 1*1 convolution
    dx = deltas[:, 0]
    dy = deltas[:, 1]
    dw = deltas[:, 2]
    dh = deltas[:, 3]
    # center, width and height of the refined prediction boxes
    pred_ctr_x = tf.add(tf.multiply(dx, widths), ctr_x)
    pred_ctr_y = tf.add(tf.multiply(dy, heights), ctr_y)
    pred_w = tf.multiply(tf.exp(dw), widths)
    pred_h = tf.multiply(tf.exp(dh), heights)
    # top-left and bottom-right corners of the refined prediction boxes
    pred_boxes0 = tf.subtract(pred_ctr_x, pred_w * 0.5)
    pred_boxes1 = tf.subtract(pred_ctr_y, pred_h * 0.5)
    pred_boxes2 = tf.add(pred_ctr_x, pred_w * 0.5)
    pred_boxes3 = tf.add(pred_ctr_y, pred_h * 0.5)

    return tf.stack([pred_boxes0, pred_boxes1, pred_boxes2, pred_boxes3], axis=1)

Since some of the refined prediction-box coordinates are negative, they must be corrected back into the image.

The region_proposal_network function continues:

# clip the box coordinates so they all lie inside the image: no less than 0, no more than the image width/height
proposals = clip_boxes_tf(proposals, im_info)
print('proposals', proposals)

Output

proposals tf.Tensor(
[[  0.         0.       100.012924  56.020565]
 [  0.         0.       192.09363  104.08067 ]
 [  0.         0.       375.88458  200.01782 ]
 ...
 [556.0016   415.98993  599.       499.      ]
 [512.0104   328.03198  599.       499.      ]
 [423.9528   151.9621   599.       499.      ]], shape=(10944, 4), dtype=float32)

Here let's look at the clip_boxes_tf function.

def clip_boxes_tf(boxes, im_info):
    '''
    Clip the box coordinates so they all lie inside the image:
    no less than 0, no more than the image width/height
    '''
    b0 = tf.maximum(tf.minimum(boxes[:, 0], im_info[1] - 1), 0)
    b1 = tf.maximum(tf.minimum(boxes[:, 1], im_info[0] - 1), 0)
    b2 = tf.maximum(tf.minimum(boxes[:, 2], im_info[1] - 1), 0)
    b3 = tf.maximum(tf.minimum(boxes[:, 3], im_info[0] - 1), 0)
    return tf.stack([b0, b1, b2, b3], axis=1)

The region_proposal_network function continues:

# non-maximum suppression; returns the selected indices
indices = tf.image.non_max_suppression(boxes=proposals,
                                       scores=scores,
                                       max_output_size=2000,
                                       iou_threshold=0.7)
print('indices', indices)
# gather the corresponding proposals by index
boxes = tf.gather(proposals, indices)
boxes = tf.cast(boxes, tf.float32)
# gather the corresponding scores by index
scores = tf.gather(scores, indices)
scores = tf.reshape(scores, shape=(-1, 1))

# Only support single image as input
batch_inds = tf.zeros((tf.shape(indices)[0], 1), dtype=tf.float32)
# prepend an all-zero column to the output proposals
blob = tf.concat([batch_inds, boxes], 1)
print('blob', blob)

Output

indices tf.Tensor([8022 8337 8412 ... 6363 6336 7722], shape=(1595,), dtype=int32)
blob tf.Tensor(
[[  0.      216.15192 311.95346 344.23438 439.97708]
 [  0.      168.22734 327.92365 296.2519  456.01996]
 [  0.      315.90656 303.90228 404.00177 480.20178]
 ...
 [  0.      283.46658 247.80559 467.6185  343.86914]
 [  0.      235.49017 247.79012 419.61877 343.86444]
 [  0.      267.30078 311.71582 451.43933 407.80243]], shape=(1595, 5), dtype=float32)

For background on non-maximum suppression, see the NMS (non-maximum suppression) section of Tensorflow的圖像操作. First, the tf.image.non_max_suppression function.

Parameters:

  • boxes: a 2-D float Tensor of shape [num_boxes, 4].
  • scores: a 1-D float Tensor of shape [num_boxes], giving a single score for each box (each row of boxes).
  • max_output_size: a scalar integer Tensor giving the maximum number of boxes to select via non-max suppression.
  • iou_threshold: a float giving the threshold for deciding whether boxes overlap too much with respect to IoU.

It selects a subset of the bounding boxes in descending order of score.

Boxes that have a high intersection-over-union (IoU) overlap with previously selected boxes are pruned. Bounding boxes are supplied as [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of corners, and the coordinates can be given normalized (i.e. lying in the interval [0, 1]) or absolute. Note that the algorithm is agnostic to where the origin of the coordinate system lies, and is invariant to orthogonal transformations and translations of it, so translating or reflecting the coordinate system makes it select the same boxes. The output of this operation is a set of integer indices into the input collection of bounding boxes, representing the selected boxes; the coordinates of the selected boxes can then be obtained with the tf.gather operation.

The region_proposal_network function continues:

# prepare the classification and regression labels and weights for training
A = 9
# 10944: the number of prediction boxes
total_anchors = tf.shape(anchors)[0]
# allow boxes to sit over the edge by a small amount
_allowed_border = 0

# feature map height (32) and width (38)
height = tf.shape(rpn_cls_score)[1]
width = tf.shape(rpn_cls_score)[2]
# indices of the prediction boxes that lie entirely inside the original image;
# here only 2870 prediction boxes are fully inside the image
inds_inside = tf.reshape(tf.where(
    (anchors[:, 0] >= -_allowed_border) &
    (anchors[:, 1] >= -_allowed_border) &
    (anchors[:, 2] < (im_info[1] + _allowed_border)) &  # width
    (anchors[:, 3] < (im_info[0] + _allowed_border))  # height
), shape=(-1,))
print('inds_inside', inds_inside)
# gather the prediction boxes inside the image
anchors = tf.gather(anchors, inds_inside)
print('anchors', anchors)

Output

inds_inside tf.Tensor([1080 1089 1098 ... 9495 9504 9513], shape=(2870,), dtype=int64)
anchors tf.Tensor(
[[ 12.   8. 195. 103.]
 [ 28.   8. 211. 103.]
 [ 44.   8. 227. 103.]
 ...
 [380. 392. 563. 487.]
 [396. 392. 579. 487.]
 [412. 392. 595. 487.]], shape=(2870, 4), dtype=float32)

Next we need to work with the ground-truth annotated regions.

# label: 1 is positive, 0 is negative, -1 is don't care
# labels start out as all -1
labels = tf.zeros_like(inds_inside, dtype=tf.float32)
labels -= 1.
ones = tf.ones_like(inds_inside, dtype=tf.float32)
zeros = tf.zeros_like(inds_inside, dtype=tf.float32)
# overlaps between the anchors and the gt boxes
# compute the IoU between every prediction box and every ground-truth box
overlaps = bbox_overlaps_tf(anchors, gt_boxes[:, :4])
print('overlaps', overlaps)
# for each prediction box, the index of the ground-truth box with the largest IoU
argmax_overlaps = tf.cast(tf.argmax(overlaps, axis=1), dtype=tf.int32)
print('argmax_overlaps', argmax_overlaps)
# the largest IoU each prediction box achieves over the ground-truth boxes
# (used by the label thresholding below)
max_overlaps = tf.reduce_max(overlaps, axis=1)
print('max_overlaps', max_overlaps)

Output

overlaps tf.Tensor(
[[0.05193141 0.02128767]
 [0.05481446 0.02483504]
 [0.05481446 0.02840714]
 ...
 [0.03245405 0.        ]
 [0.02597288 0.        ]
 [0.01957259 0.        ]], shape=(2870, 2), dtype=float32)
argmax_overlaps tf.Tensor([0 0 0 ... 0 0 0], shape=(2870,), dtype=int32)
max_overlaps tf.Tensor([0.05193141 0.05481446 0.05481446 ... 0.03245405 0.02597288 0.01957259], shape=(2870,), dtype=float32)

For background on IoU, see the U-Net enhancement section of Kaggle賽題分析(二).

Here let's look at the bbox_overlaps_tf function.

def bbox_overlaps_tf(boxlist1, boxlist2):
    """ 計算iou
    :param boxlist1: N*4
    :param boxlist2: M*4
    Returns: N*M
    """
    intersections = intersection(boxlist1, boxlist2)
    areas1 = area(boxlist1)
    areas2 = area(boxlist2)
    unions = (
        tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
    return tf.where(
        tf.equal(intersections, 0.0),
        tf.zeros_like(intersections), tf.truediv(intersections, unions))

def intersection(boxlist1, boxlist2):
    """計算box之間的交叉面積
    :param boxlist1: N*4
    :param boxlist2: M*4
    Returns: N*M
    """
    x_min1, y_min1, x_max1, y_max1 = tf.split(
        value=boxlist1, num_or_size_splits=4, axis=1)
    x_min2, y_min2, x_max2, y_max2 = tf.split(
        value=boxlist2, num_or_size_splits=4, axis=1)
    all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
    all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
    intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin + 1.)
    all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
    all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
    intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin + 1.)
    return intersect_heights * intersect_widths


def area(boxlist):
    """ 計算面積
    :param boxlist1: N*4.
    """
    x_min, y_min, x_max, y_max = tf.split(value=boxlist, num_or_size_splits=4, axis=1)
    return tf.squeeze((y_max - y_min + 1.) * (x_max - x_min + 1.), [1])

Because only two boxes were annotated here, the last dimension of overlaps is 2, giving each prediction box's IoU with the first and the second annotated box respectively. Had three boxes been annotated, the last dimension of overlaps would be 3.

The region_proposal_network function continues:

# for each ground-truth box, the index of the prediction box with the largest IoU
gt_argmax_overlaps = tf.cast(tf.argmax(overlaps, axis=0), dtype=tf.int32)
print('gt_argmax_overlaps', gt_argmax_overlaps)
# pair gt_argmax_overlaps with sequential ground-truth indices
max_overlaps_gather_nd_inds = tf.stack([gt_argmax_overlaps, tf.range(tf.shape(overlaps)[1])], axis=1)
print('max_overlaps_gather_nd_inds', max_overlaps_gather_nd_inds)
# the largest IoU each ground-truth box achieves with any prediction box
gt_max_overlaps = tf.gather_nd(overlaps, max_overlaps_gather_nd_inds)
print('gt_max_overlaps', gt_max_overlaps)
# indices of the prediction boxes whose IoU equals a ground-truth box's maximum IoU
gt_argmax_overlaps = tf.where(overlaps == gt_max_overlaps)[:, 0]
print('gt_argmax_overlaps', gt_argmax_overlaps)
# prediction boxes whose largest IoU is below 0.3 get label 0; the others keep their value
labels = tf.where(max_overlaps < 0.3, zeros, labels)
print(tf.where(labels < 0))

Output

gt_argmax_overlaps tf.Tensor([410 943], shape=(2,), dtype=int32)
max_overlaps_gather_nd_inds tf.Tensor(
[[410   0]
 [943   1]], shape=(2, 2), dtype=int32)
gt_max_overlaps tf.Tensor([0.30275452 0.7433778 ], shape=(2,), dtype=float32)
gt_argmax_overlaps tf.Tensor(
[ 410  415  420  425  535  540  545  550  660  665  670  675  793  799
  805  811  943  945  951  957  963 1095 1097 1103 1109 1115 1247 1249
 1255 1261 1267 1399 1401 1407 1413 1419 1551 1553 1559 1565 1571 1705
 1711 1717 1723 1857 1863 1869 1875 2009 2015 2021 2027 2153 2158 2163
 2168 2278 2283 2288 2293 2403 2408 2413 2418 2523 2527 2531 2535 2626
 2630 2634 2638], shape=(73,), dtype=int64)
tf.Tensor(
[[ 389]
 [ 393]
 [ 397]
 [ 399]
 [ 402]
 [ 404]
 [ 407]
 [ 410]
 [ 412]
 [ 415]
 [ 417]
 [ 420]
 [ 422]
 [ 425]
 [ 430]
 [ 514]
 [ 518]
 [ 521]
 [ 522]
 [ 524]
 [ 525]
 [ 526]
 [ 527]
 [ 529]
 [ 530]
 [ 531]
 [ 532]
 [ 535]
 [ 536]
 [ 537]
 [ 540]
 [ 542]
 [ 545]
 [ 547]
 [ 550]
 [ 552]
 [ 555]
 [ 639]
 [ 643]
 [ 646]
 [ 647]
 [ 649]
 [ 650]
 [ 651]
 [ 652]
 [ 654]
 [ 655]
 [ 656]
 [ 657]
 [ 660]
 [ 661]
 [ 662]
 [ 665]
 [ 667]
 [ 670]
 [ 672]
 [ 675]
 [ 677]
 [ 680]
 [ 682]
 [ 685]
 [ 764]
 [ 767]
 [ 769]
 [ 772]
 [ 774]
 [ 776]
 [ 777]
 [ 779]
 [ 780]
 [ 781]
 [ 782]
 [ 783]
 [ 785]
 [ 786]
 [ 787]
 [ 788]
 [ 789]
 [ 791]
 [ 793]
 [ 794]
 [ 795]
 [ 797]
 [ 799]
 [ 801]
 [ 803]
 [ 805]
 [ 807]
 [ 809]
 [ 811]
 [ 813]
 [ 815]
 [ 817]
 [ 819]
 [ 823]
 [ 829]
 [ 916]
 [ 919]
 [ 921]
 [ 924]
 [ 926]
 [ 928]
 [ 929]
 [ 931]
 [ 932]
 [ 933]
 [ 934]
 [ 935]
 [ 937]
 [ 938]
 [ 939]
 [ 940]
 [ 941]
 [ 943]
 [ 945]
 [ 946]
 [ 947]
 [ 949]
 [ 951]
 [ 953]
 [ 955]
 [ 957]
 [ 959]
 [ 961]
 [ 963]
 [ 965]
 [ 967]
 [ 969]
 [ 971]
 [ 973]
 [ 975]
 [ 977]
 [ 981]
 [1068]
 [1071]
 [1073]
 [1076]
 [1078]
 [1080]
 [1081]
 [1083]
 [1084]
 [1085]
 [1086]
 [1087]
 [1089]
 [1090]
 [1091]
 [1092]
 [1093]
 [1095]
 [1097]
 [1098]
 [1099]
 [1101]
 [1103]
 [1105]
 [1107]
 [1109]
 [1111]
 [1113]
 [1115]
 [1117]
 [1119]
 [1121]
 [1123]
 [1125]
 [1127]
 [1129]
 [1133]
 [1220]
 [1223]
 [1225]
 [1228]
 [1230]
 [1232]
 [1233]
 [1235]
 [1236]
 [1237]
 [1238]
 [1239]
 [1241]
 [1242]
 [1243]
 [1244]
 [1245]
 [1247]
 [1249]
 [1250]
 [1251]
 [1253]
 [1255]
 [1257]
 [1259]
 [1261]
 [1263]
 [1265]
 [1267]
 [1269]
 [1271]
 [1273]
 [1275]
 [1277]
 [1279]
 [1281]
 [1285]
 [1372]
 [1375]
 [1377]
 [1380]
 [1382]
 [1384]
 [1385]
 [1387]
 [1388]
 [1389]
 [1390]
 [1391]
 [1393]
 [1394]
 [1395]
 [1396]
 [1397]
 [1399]
 [1401]
 [1402]
 [1403]
 [1405]
 [1407]
 [1409]
 [1411]
 [1413]
 [1415]
 [1417]
 [1419]
 [1421]
 [1423]
 [1425]
 [1427]
 [1429]
 [1431]
 [1433]
 [1437]
 [1524]
 [1527]
 [1529]
 [1532]
 [1534]
 [1536]
 [1537]
 [1539]
 [1540]
 [1541]
 [1542]
 [1543]
 [1545]
 [1546]
 [1547]
 [1548]
 [1549]
 [1551]
 [1553]
 [1554]
 [1555]
 [1557]
 [1559]
 [1561]
 [1563]
 [1565]
 [1567]
 [1569]
 [1571]
 [1573]
 [1575]
 [1577]
 [1579]
 [1581]
 [1583]
 [1585]
 [1589]
 [1676]
 [1679]
 [1681]
 [1684]
 [1686]
 [1688]
 [1689]
 [1691]
 [1692]
 [1693]
 [1694]
 [1695]
 [1697]
 [1698]
 [1699]
 [1700]
 [1701]
 [1703]
 [1705]
 [1706]
 [1707]
 [1709]
 [1711]
 [1713]
 [1715]
 [1717]
 [1719]
 [1721]
 [1723]
 [1725]
 [1727]
 [1729]
 [1731]
 [1735]
 [1741]
 [1828]
 [1831]
 [1833]
 [1836]
 [1838]
 [1840]
 [1841]
 [1843]
 [1844]
 [1845]
 [1846]
 [1847]
 [1849]
 [1850]
 [1851]
 [1852]
 [1853]
 [1855]
 [1857]
 [1858]
 [1859]
 [1861]
 [1863]
 [1865]
 [1867]
 [1869]
 [1871]
 [1873]
 [1875]
 [1877]
 [1879]
 [1881]
 [1883]
 [1887]
 [1983]
 [1985]
 [1988]
 [1990]
 [1993]
 [1995]
 [1996]
 [1997]
 [1999]
 [2001]
 [2002]
 [2003]
 [2005]
 [2007]
 [2009]
 [2011]
 [2013]
 [2015]
 [2017]
 [2019]
 [2021]
 [2023]
 [2025]
 [2027]
 [2029]
 [2033]
 [2132]
 [2136]
 [2140]
 [2145]
 [2150]
 [2153]
 [2155]
 [2158]
 [2160]
 [2163]
 [2165]
 [2168]
 [2173]
 [2261]
 [2265]
 [2270]
 [2275]
 [2278]
 [2280]
 [2283]
 [2285]
 [2288]
 [2293]
 [2298]
 [2403]
 [2408]
 [2413]
 [2418]
 [2423]
 [2523]
 [2527]
 [2531]
 [2535]
 [2539]
 [2626]
 [2630]
 [2634]
 [2638]
 [2642]], shape=(415, 1), dtype=int64)

The result shows that some of the prediction boxes' maximum IoU values are indeed above 0.3.

The region_proposal_network function continues:

# unique indices of the prediction boxes whose IoU equals a ground-truth box's maximum IoU
unique_gt_argmax_overlaps = tf.unique(gt_argmax_overlaps)[0]
print('unique_gt_argmax_overlaps', unique_gt_argmax_overlaps)
# gather the labels at those unique indices and negate them (turning -1 into 1)
highest_fg_label = tf.gather(labels, unique_gt_argmax_overlaps) * -1.
print('highest_fg_label', highest_fg_label)
# expand the index tensor by one dimension
highest_gt_row_ids_expand_dim = tf.expand_dims(unique_gt_argmax_overlaps, axis=1)
print('highest_gt_row_ids_expand_dim', highest_gt_row_ids_expand_dim)
# scatter the highest_fg_label values back into labels at those indices
labels = tf.tensor_scatter_nd_update(labels, highest_gt_row_ids_expand_dim, highest_fg_label)
print(tf.where(labels > 0))
# prediction boxes whose largest IoU is at least 0.7 get label 1; the others keep their value
labels = tf.where(max_overlaps >= 0.7, ones, labels)
print(tf.where(labels > 0))

Output

unique_gt_argmax_overlaps tf.Tensor(
[ 410  415  420  425  535  540  545  550  660  665  670  675  793  799
  805  811  943  945  951  957  963 1095 1097 1103 1109 1115 1247 1249
 1255 1261 1267 1399 1401 1407 1413 1419 1551 1553 1559 1565 1571 1705
 1711 1717 1723 1857 1863 1869 1875 2009 2015 2021 2027 2153 2158 2163
 2168 2278 2283 2288 2293 2403 2408 2413 2418 2523 2527 2531 2535 2626
 2630 2634 2638], shape=(73,), dtype=int64)
highest_fg_label tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.], shape=(73,), dtype=float32)
highest_gt_row_ids_expand_dim tf.Tensor(
[[ 410]
 [ 415]
 [ 420]
 [ 425]
 [ 535]
 [ 540]
 [ 545]
 [ 550]
 [ 660]
 [ 665]
 [ 670]
 [ 675]
 [ 793]
 [ 799]
 [ 805]
 [ 811]
 [ 943]
 [ 945]
 [ 951]
 [ 957]
 [ 963]
 [1095]
 [1097]
 [1103]
 [1109]
 [1115]
 [1247]
 [1249]
 [1255]
 [1261]
 [1267]
 [1399]
 [1401]
 [1407]
 [1413]
 [1419]
 [1551]
 [1553]
 [1559]
 [1565]
 [1571]
 [1705]
 [1711]
 [1717]
 [1723]
 [1857]
 [1863]
 [1869]
 [1875]
 [2009]
 [2015]
 [2021]
 [2027]
 [2153]
 [2158]
 [2163]
 [2168]
 [2278]
 [2283]
 [2288]
 [2293]
 [2403]
 [2408]
 [2413]
 [2418]
 [2523]
 [2527]
 [2531]
 [2535]
 [2626]
 [2630]
 [2634]
 [2638]], shape=(73, 1), dtype=int64)
tf.Tensor(
[[ 410]
 [ 415]
 [ 420]
 [ 425]
 [ 535]
 [ 540]
 [ 545]
 [ 550]
 [ 660]
 [ 665]
 [ 670]
 [ 675]
 [ 793]
 [ 799]
 [ 805]
 [ 811]
 [ 943]
 [ 945]
 [ 951]
 [ 957]
 [ 963]
 [1095]
 [1097]
 [1103]
 [1109]
 [1115]
 [1247]
 [1249]
 [1255]
 [1261]
 [1267]
 [1399]
 [1401]
 [1407]
 [1413]
 [1419]
 [1551]
 [1553]
 [1559]
 [1565]
 [1571]
 [1705]
 [1711]
 [1717]
 [1723]
 [1857]
 [1863]
 [1869]
 [1875]
 [2009]
 [2015]
 [2021]
 [2027]
 [2153]
 [2158]
 [2163]
 [2168]
 [2278]
 [2283]
 [2288]
 [2293]
 [2403]
 [2408]
 [2413]
 [2418]
 [2523]
 [2527]
 [2531]
 [2535]
 [2626]
 [2630]
 [2634]
 [2638]], shape=(73, 1), dtype=int64)
tf.Tensor(
[[ 410]
 [ 415]
 [ 420]
 [ 425]
 [ 535]
 [ 540]
 [ 545]
 [ 550]
 [ 660]
 [ 665]
 [ 670]
 [ 675]
 [ 791]
 [ 793]
 [ 799]
 [ 805]
 [ 811]
 [ 937]
 [ 943]
 [ 945]
 [ 951]
 [ 957]
 [ 963]
 [1089]
 [1095]
 [1097]
 [1103]
 [1109]
 [1115]
 [1241]
 [1247]
 [1249]
 [1255]
 [1261]
 [1267]
 [1393]
 [1399]
 [1401]
 [1407]
 [1413]
 [1419]
 [1545]
 [1551]
 [1553]
 [1559]
 [1565]
 [1571]
 [1705]
 [1711]
 [1717]
 [1723]
 [1857]
 [1863]
 [1869]
 [1875]
 [2009]
 [2015]
 [2021]
 [2027]
 [2153]
 [2158]
 [2163]
 [2168]
 [2278]
 [2283]
 [2288]
 [2293]
 [2403]
 [2408]
 [2413]
 [2418]
 [2523]
 [2527]
 [2531]
 [2535]
 [2626]
 [2630]
 [2634]
 [2638]], shape=(79, 1), dtype=int64)

From the output, 73 of the predicted boxes have a maximum IoU above 0.3 that also equals the per-gt-box maximum IoU, and a further 6 predicted boxes have a maximum IoU of at least 0.7 (79 positives in total).

Continuing the region_proposal_network function

# subsample positive labels if we have too many
# 256 is the batch_size the RPN samples in one pass
# foreground quota: at most 128 positives
num_fg = int(0.5 * 256)
# indices of the positive (foreground) samples
fg_inds = tf.reshape(tf.where(labels == 1), shape=(-1,))
print('fg_inds', fg_inds)
# count the positives (entries of labels equal to 1)
fg_inds_num = tf.shape(fg_inds)[0]
print('fg_inds_num', fg_inds_num)
# check whether the positive count exceeds the quota of 128
fg_flag = tf.cast(fg_inds_num > num_fg, dtype=tf.float32)
print('fg_flag', fg_flag)
# Randomly disable (set to -1) any positives beyond the quota of 128; here there are only 79 positives, far fewer than 128, so labels are unchanged
labels = fg_flag * random_disable_labels(labels, fg_inds, fg_inds_num - num_fg) + \
         (1.0 - fg_flag) * labels
print(tf.where(labels > 0))

Output

fg_inds tf.Tensor(
[ 410  415  420  425  535  540  545  550  660  665  670  675  791  793
  799  805  811  937  943  945  951  957  963 1089 1095 1097 1103 1109
 1115 1241 1247 1249 1255 1261 1267 1393 1399 1401 1407 1413 1419 1545
 1551 1553 1559 1565 1571 1705 1711 1717 1723 1857 1863 1869 1875 2009
 2015 2021 2027 2153 2158 2163 2168 2278 2283 2288 2293 2403 2408 2413
 2418 2523 2527 2531 2535 2626 2630 2634 2638], shape=(79,), dtype=int64)
fg_inds_num tf.Tensor(79, shape=(), dtype=int32)
fg_flag tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(
[[ 410]
 [ 415]
 [ 420]
 [ 425]
 [ 535]
 [ 540]
 [ 545]
 [ 550]
 [ 660]
 [ 665]
 [ 670]
 [ 675]
 [ 791]
 [ 793]
 [ 799]
 [ 805]
 [ 811]
 [ 937]
 [ 943]
 [ 945]
 [ 951]
 [ 957]
 [ 963]
 [1089]
 [1095]
 [1097]
 [1103]
 [1109]
 [1115]
 [1241]
 [1247]
 [1249]
 [1255]
 [1261]
 [1267]
 [1393]
 [1399]
 [1401]
 [1407]
 [1413]
 [1419]
 [1545]
 [1551]
 [1553]
 [1559]
 [1565]
 [1571]
 [1705]
 [1711]
 [1717]
 [1723]
 [1857]
 [1863]
 [1869]
 [1875]
 [2009]
 [2015]
 [2021]
 [2027]
 [2153]
 [2158]
 [2163]
 [2168]
 [2278]
 [2283]
 [2288]
 [2293]
 [2403]
 [2408]
 [2413]
 [2418]
 [2523]
 [2527]
 [2531]
 [2535]
 [2626]
 [2630]
 [2634]
 [2638]], shape=(79, 1), dtype=int64)

From this we can see that a predicted box becomes a positive (foreground) sample in one of two ways: either its own maximum IoU is at least 0.7, or its maximum IoU equals the per-gt-box maximum IoU.
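To make the two rules concrete, here is a toy sketch (not part of the original code; the IoU values are invented) that labels 4 anchors against 2 gt boxes the same way:

import tensorflow as tf

# 4 anchors x 2 gt boxes; values invented for illustration
overlaps = tf.constant([[0.75, 0.10],   # rule 1: its max IoU is >= 0.7
                        [0.30, 0.05],
                        [0.20, 0.45],   # rule 2: best anchor for gt box 2
                        [0.02, 0.01]])
max_overlaps = tf.reduce_max(overlaps, axis=1)   # best IoU per anchor
gt_argmax = tf.argmax(overlaps, axis=0)          # best anchor per gt box
labels = tf.fill([4], -1.0)                      # -1 = don't care
# rule 2: the anchor with the highest IoU for each gt box becomes positive
labels = tf.tensor_scatter_nd_update(labels, tf.expand_dims(gt_argmax, axis=1),
                                     tf.ones_like(gt_argmax, dtype=tf.float32))
# rule 1: any anchor whose own max IoU is >= 0.7 becomes positive
labels = tf.where(max_overlaps >= 0.7, tf.ones_like(labels), labels)
print(labels)  # [ 1. -1.  1. -1.]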

Now let's look at the random_disable_labels function

def random_disable_labels(labels_input, inds, disable_nums):
    # shuffle the positive/negative sample indices
    shuffle_fg_inds = tf.random.shuffle(inds)
    # take the first disable_nums indices, i.e. the surplus beyond the quota
    disable_inds = shuffle_fg_inds[:disable_nums]
    # expand the surplus indices by one dimension
    disable_inds_expand_dim = tf.expand_dims(disable_inds, axis=1)
    # build a vector of -1s to disable the surplus samples
    neg_ones = tf.ones_like(disable_inds, dtype=tf.float32) * -1.
    # scatter the -1s back into labels
    return tf.tensor_scatter_nd_update(labels_input, disable_inds_expand_dim, neg_ones)

Next come the negative (background) samples; continuing the region_proposal_network function

# subsample negative labels if we have too many
# background quota: 256 minus the number of positives
num_bg = 256 - tf.shape(tf.where(labels == 1))[0]
print('num_bg', num_bg)
# bg_inds = np.where(labels == 0)[0]
# indices of the negative (background) samples
bg_inds = tf.reshape(tf.where(labels == 0), shape=(-1,))
print('bg_inds', bg_inds)
# count of negative samples
bg_inds_num = tf.shape(bg_inds)[0]
print('bg_inds_num', bg_inds_num)
# check whether the negative count exceeds the background quota
bg_flag = tf.cast(bg_inds_num > num_bg, dtype=tf.float32)
print('bg_flag', bg_flag)
# Randomly disable (set to -1) any negatives beyond the quota; here the 2455
# negatives far exceed the quota of 177, so 2278 of them are disabled
labels = bg_flag * random_disable_labels(labels, bg_inds, bg_inds_num - num_bg) + \
         (1.0 - bg_flag) * labels
print(tf.where(labels == 0))

Output

num_bg tf.Tensor(177, shape=(), dtype=int32)
bg_inds tf.Tensor([   0    1    2 ... 2867 2868 2869], shape=(2455,), dtype=int64)
bg_inds_num tf.Tensor(2455, shape=(), dtype=int32)
bg_flag tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(
[[  30]
 [  48]
 [  59]
 [  61]
 [  76]
 [  84]
 [ 114]
 [ 131]
 [ 150]
 [ 154]
 [ 172]
 [ 173]
 [ 193]
 [ 245]
 [ 248]
 [ 283]
 [ 290]
 [ 326]
 [ 327]
 [ 331]
 [ 338]
 [ 339]
 [ 350]
 [ 409]
 [ 439]
 [ 486]
 [ 496]
 [ 556]
 [ 601]
 [ 617]
 [ 629]
 [ 701]
 [ 702]
 [ 725]
 [ 726]
 [ 808]
 [ 831]
 [ 857]
 [ 864]
 [ 877]
 [ 909]
 [ 918]
 [ 962]
 [ 964]
 [1013]
 [1016]
 [1020]
 [1042]
 [1052]
 [1072]
 [1120]
 [1124]
 [1140]
 [1147]
 [1154]
 [1193]
 [1197]
 [1211]
 [1216]
 [1221]
 [1282]
 [1286]
 [1287]
 [1290]
 [1302]
 [1315]
 [1316]
 [1341]
 [1349]
 [1357]
 [1363]
 [1368]
 [1374]
 [1383]
 [1435]
 [1448]
 [1460]
 [1467]
 [1474]
 [1480]
 [1484]
 [1485]
 [1509]
 [1582]
 [1587]
 [1596]
 [1622]
 [1633]
 [1636]
 [1646]
 [1654]
 [1658]
 [1663]
 [1671]
 [1674]
 [1708]
 [1733]
 [1746]
 [1766]
 [1767]
 [1809]
 [1834]
 [1922]
 [1932]
 [1942]
 [1961]
 [1979]
 [2004]
 [2030]
 [2038]
 [2039]
 [2061]
 [2082]
 [2085]
 [2090]
 [2157]
 [2174]
 [2175]
 [2176]
 [2177]
 [2181]
 [2193]
 [2208]
 [2212]
 [2215]
 [2243]
 [2263]
 [2273]
 [2303]
 [2308]
 [2318]
 [2321]
 [2331]
 [2342]
 [2345]
 [2370]
 [2376]
 [2377]
 [2441]
 [2458]
 [2460]
 [2466]
 [2490]
 [2502]
 [2510]
 [2514]
 [2517]
 [2519]
 [2522]
 [2564]
 [2583]
 [2601]
 [2609]
 [2614]
 [2615]
 [2623]
 [2625]
 [2649]
 [2654]
 [2667]
 [2673]
 [2704]
 [2710]
 [2725]
 [2726]
 [2758]
 [2770]
 [2779]
 [2792]
 [2813]
 [2833]
 [2838]
 [2842]
 [2848]
 [2850]
 [2856]
 [2860]], shape=(177, 1), dtype=int64)

From the output, the 2455 background boxes have been cut down to 177. Continuing the region_proposal_network function

# For each predicted box, gather the coordinates of the gt box it overlaps most
gt_overlaps = tf.gather(gt_boxes, argmax_overlaps, axis=0)[:, :4]
print('gt_overlaps', gt_overlaps)
# Align each anchor with its gt box: the (dx, dy, dw, dh) from anchor to gt box are the regression targets used in the loss against the predicted boxes
bbox_targets = bbox_transform_tf(anchors, gt_overlaps)
print('bbox_targets', bbox_targets)

Output

gt_overlaps tf.Tensor(
[[ 21.2  33.1 461.2 561.3]
 [ 21.2  33.1 461.2 561.3]
 [ 21.2  33.1 461.2 561.3]
 ...
 [ 21.2  33.1 461.2 561.3]
 [ 21.2  33.1 461.2 561.3]
 [ 21.2  33.1 461.2 561.3]], shape=(2870, 4), dtype=float32)
bbox_targets tf.Tensor(
[[ 0.7483696   2.5177085   0.87410915  1.7070184 ]
 [ 0.661413    2.5177085   0.87410915  1.7070184 ]
 [ 0.5744565   2.5177085   0.87410915  1.7070184 ]
 ...
 [-1.2516304  -1.4822916   0.87410915  1.7070183 ]
 [-1.3385869  -1.4822916   0.87410915  1.7070183 ]
 [-1.4255434  -1.4822916   0.87410915  1.7070183 ]], shape=(2870, 4), dtype=float32)

Now let's look at the bbox_transform_tf function

def bbox_transform_tf(ex_rois, gt_rois):
    # width, height, and centre of the anchors (the boxes inside the image)
    ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
    ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
    ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths
    ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights
    # width, height, and centre of the gt boxes
    gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0
    gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0
    gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths
    gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights
    # regression targets: normalized centre offsets and log-scale width/height ratios
    targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = tf.math.log(gt_widths / ex_widths)
    targets_dh = tf.math.log(gt_heights / ex_heights)

    targets = tf.stack([targets_dx, targets_dy, targets_dw, targets_dh], axis=1)
    return targets
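For completeness, the inverse operation is what the proposal layer applies to the RPN's predicted deltas to recover box coordinates. It is not shown in this walkthrough, so the following is only a sketch in the style of the standard Faster R-CNN decode (the name bbox_transform_inv_tf is mine):

import tensorflow as tf

def bbox_transform_inv_tf(boxes, deltas):
    # width, height, and centre of the input boxes (same +1 convention as above)
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights
    # undo the encoding: shift the centre, exponentiate the log-scale terms
    pred_ctr_x = deltas[:, 0] * widths + ctr_x
    pred_ctr_y = deltas[:, 1] * heights + ctr_y
    pred_w = tf.exp(deltas[:, 2]) * widths
    pred_h = tf.exp(deltas[:, 3]) * heights
    # back to [x1, y1, x2, y2]
    return tf.stack([pred_ctr_x - 0.5 * pred_w,
                     pred_ctr_y - 0.5 * pred_h,
                     pred_ctr_x + 0.5 * pred_w,
                     pred_ctr_y + 0.5 * pred_h], axis=1)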

Continuing the region_proposal_network function

# Create an all-zero (2870, 4) matrix, one row per predicted box inside the image
bbox_inside_weights = tf.zeros((tf.shape(inds_inside)[0], 4), dtype=tf.float32, name='bbox_inside_weights')
print('bbox_inside_weights', bbox_inside_weights)
# only the positive ones have regression targets
# indices of all positive samples
bbox_inside_inds = tf.reshape(tf.where(labels == 1), shape=[-1, ])
print('bbox_inside_inds', bbox_inside_inds)
# gather the rows at the positive indices and add 1 to every component
bbox_inside_inds_weights = tf.gather(bbox_inside_weights, bbox_inside_inds) + (1.0, 1.0, 1.0, 1.0)
print('bbox_inside_inds_weights', bbox_inside_inds_weights)
# expand the positive indices by one dimension
bbox_inside_inds_expand = tf.expand_dims(bbox_inside_inds, axis=1)
print('bbox_inside_inds_expand', bbox_inside_inds_expand)
# scatter the 1s back into bbox_inside_weights at the positive indices
bbox_inside_weights = tf.tensor_scatter_nd_update(bbox_inside_weights,
                                                  bbox_inside_inds_expand,
                                                  bbox_inside_inds_weights)
print(tf.unique(tf.where(bbox_inside_weights > 0)[:, 0])[0])

Output

bbox_inside_weights tf.Tensor(
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 ...
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]], shape=(2870, 4), dtype=float32)
bbox_inside_inds tf.Tensor(
[ 410  415  420  425  535  540  545  550  660  665  670  675  791  793
  799  805  811  937  943  945  951  957  963 1089 1095 1097 1103 1109
 1115 1241 1247 1249 1255 1261 1267 1393 1399 1401 1407 1413 1419 1545
 1551 1553 1559 1565 1571 1705 1711 1717 1723 1857 1863 1869 1875 2009
 2015 2021 2027 2153 2158 2163 2168 2278 2283 2288 2293 2403 2408 2413
 2418 2523 2527 2531 2535 2626 2630 2634 2638], shape=(79,), dtype=int64)
bbox_inside_inds_weights tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(79, 4), dtype=float32)
bbox_inside_inds_expand tf.Tensor(
[[ 410]
 [ 415]
 [ 420]
 [ 425]
 [ 535]
 [ 540]
 [ 545]
 [ 550]
 [ 660]
 [ 665]
 [ 670]
 [ 675]
 [ 791]
 [ 793]
 [ 799]
 [ 805]
 [ 811]
 [ 937]
 [ 943]
 [ 945]
 [ 951]
 [ 957]
 [ 963]
 [1089]
 [1095]
 [1097]
 [1103]
 [1109]
 [1115]
 [1241]
 [1247]
 [1249]
 [1255]
 [1261]
 [1267]
 [1393]
 [1399]
 [1401]
 [1407]
 [1413]
 [1419]
 [1545]
 [1551]
 [1553]
 [1559]
 [1565]
 [1571]
 [1705]
 [1711]
 [1717]
 [1723]
 [1857]
 [1863]
 [1869]
 [1875]
 [2009]
 [2015]
 [2021]
 [2027]
 [2153]
 [2158]
 [2163]
 [2168]
 [2278]
 [2283]
 [2288]
 [2293]
 [2403]
 [2408]
 [2413]
 [2418]
 [2523]
 [2527]
 [2531]
 [2535]
 [2626]
 [2630]
 [2634]
 [2638]], shape=(79, 1), dtype=int64)
tf.Tensor(
[ 410  415  420  425  535  540  545  550  660  665  670  675  791  793
  799  805  811  937  943  945  951  957  963 1089 1095 1097 1103 1109
 1115 1241 1247 1249 1255 1261 1267 1393 1399 1401 1407 1413 1419 1545
 1551 1553 1559 1565 1571 1705 1711 1717 1723 1857 1863 1869 1875 2009
 2015 2021 2027 2153 2158 2163 2168 2278 2283 2288 2293 2403 2408 2413
 2418 2523 2527 2531 2535 2626 2630 2634 2638], shape=(79,), dtype=int64)
Continuing the region_proposal_network function

# Create an all-zero (2870, 4) matrix, one row per predicted box inside the image
bbox_outside_weights = tf.zeros((tf.shape(inds_inside)[0], 4), dtype=tf.float32, name='bbox_outside_weights')
rpn_positive_weight = -1
if rpn_positive_weight < 0:
    # uniform weighting of examples (given non-uniform sampling)
    # total number of sampled positives and negatives (256 here)
    num_examples = tf.reduce_sum(tf.cast(labels >= 0, dtype=tf.float32))
    print('num_examples', num_examples)
    # uniform per-sample weight of 1/num_examples for both positives and negatives
    positive_weights = tf.ones((1, 4), dtype=tf.float32) / num_examples
    print('positive_weights', positive_weights)
    negative_weights = tf.ones((1, 4), dtype=tf.float32) / num_examples
    print('negative_weights', negative_weights)

else:
    assert ((rpn_positive_weight > 0) & (rpn_positive_weight < 1))
    positive_weights = rpn_positive_weight / tf.reduce_sum(tf.cast(labels == 1, dtype=tf.float32))
    negative_weights = (1.0 - rpn_positive_weight) / tf.reduce_sum(tf.cast(labels == 0, dtype=tf.float32))

Output

num_examples tf.Tensor(256.0, shape=(), dtype=float32)
positive_weights tf.Tensor([[0.00390625 0.00390625 0.00390625 0.00390625]], shape=(1, 4), dtype=float32)
negative_weights tf.Tensor([[0.00390625 0.00390625 0.00390625 0.00390625]], shape=(1, 4), dtype=float32)
Continuing the region_proposal_network function

# positive sample indices (79 of them)
bbox_outside_positive_inds = bbox_inside_inds
# negative sample indices (177 of them)
bbox_outside_negative_inds = tf.reshape(tf.where(labels == 0), shape=[-1, ])
# gather the rows of the (2870, 4) zero matrix at the positive indices and add the positive weight
bbox_outside_positive_inds_weights = tf.gather(bbox_outside_weights,
                                               bbox_outside_positive_inds) + positive_weights
# gather the rows at the negative indices and add the negative weight
bbox_outside_negative_inds_weights = tf.gather(bbox_outside_weights,
                                               bbox_outside_negative_inds) + negative_weights
# expand the positive indices by one dimension
bbox_outside_positive_inds_expand = tf.expand_dims(bbox_outside_positive_inds, axis=1)
# expand the negative indices by one dimension
bbox_outside_negative_inds_expand = tf.expand_dims(bbox_outside_negative_inds, axis=1)
# scatter the positive weights back into bbox_outside_weights at those indices
bbox_outside_weights = tf.tensor_scatter_nd_update(bbox_outside_weights,
                                                   bbox_outside_positive_inds_expand,
                                                   bbox_outside_positive_inds_weights)
# scatter the negative weights back into bbox_outside_weights at those indices
bbox_outside_weights = tf.tensor_scatter_nd_update(bbox_outside_weights,
                                                   bbox_outside_negative_inds_expand,
                                                   bbox_outside_negative_inds_weights)
print(tf.unique(tf.where(bbox_outside_weights > 0)[:, 0])[0])

Output

tf.Tensor(
[  13   16   22   35   36   51   56   83  111  120  144  155  184  214
  219  260  279  282  286  335  339  353  361  377  408  409  410  415
  420  425  426  439  454  472  476  485  494  507  520  528  533  535
  540  545  550  573  622  641  659  660  665  670  675  678  686  707
  710  718  728  729  755  759  791  793  799  805  811  865  875  912
  917  925  937  943  945  951  957  963  996 1002 1033 1037 1046 1065
 1089 1095 1097 1103 1104 1108 1109 1115 1116 1120 1162 1164 1167 1169
 1172 1180 1189 1202 1208 1217 1241 1247 1249 1255 1261 1267 1280 1295
 1312 1317 1324 1339 1346 1360 1361 1362 1376 1393 1399 1401 1407 1413
 1419 1447 1461 1463 1474 1475 1513 1522 1523 1545 1551 1553 1556 1559
 1565 1566 1571 1606 1608 1651 1667 1671 1680 1696 1705 1708 1711 1717
 1723 1734 1748 1750 1753 1763 1770 1788 1793 1800 1812 1814 1816 1829
 1854 1857 1863 1869 1875 1897 1942 1956 1960 1966 1973 1980 2009 2015
 2021 2027 2031 2043 2071 2079 2082 2092 2101 2110 2131 2147 2149 2153
 2158 2163 2168 2216 2242 2245 2249 2253 2267 2278 2283 2288 2293 2327
 2368 2382 2383 2403 2404 2405 2408 2413 2418 2466 2495 2496 2510 2519
 2523 2527 2531 2535 2551 2552 2586 2602 2603 2605 2608 2620 2626 2630
 2634 2638 2656 2660 2663 2693 2701 2702 2706 2721 2730 2755 2784 2791
 2803 2851 2861 2866], shape=(256,), dtype=int64)
Continuing the region_proposal_network function

# Map the labels, targets, and weights processed above back to the size of the
# full anchor set, so the layer's outputs line up with what was passed in
# unmap labels
labels = unmap(labels, total_anchors, inds_inside, fill=-1, type='labels')
print(tf.unique(tf.where(labels >= 0)[:, 0])[0])
# unmap the regression targets
bbox_targets = unmap(bbox_targets, total_anchors, inds_inside, fill=0, type='bbox_targets')
print(tf.unique(tf.where(bbox_targets > 0)[:, 0])[0])
# unmap the positive-sample mask
bbox_inside_weights = unmap(bbox_inside_weights, total_anchors, inds_inside, fill=0,
                                 type='bbox_inside_weights')
print(tf.unique(tf.where(bbox_inside_weights > 0)[:, 0])[0])
# unmap the positive/negative sample weights
bbox_outside_weights = unmap(bbox_outside_weights, total_anchors, inds_inside, fill=0,
                                  type='bbox_outside_weights')
print(tf.unique(tf.where(bbox_outside_weights > 0)[:, 0])[0])

# labels: reshape to the anchor layout; the anchor total is (image width/16 = 32) * (image height/16 = 38) * 9 = 10944
# reshape labels
rpn_labels = tf.reshape(labels, (1, height, width, A))
# reshape the regression targets
rpn_bbox_targets = tf.reshape(bbox_targets, (1, height, width, A * 4), name='rpn_bbox_targets')
# reshape the positive-sample mask
rpn_bbox_inside_weights = tf.reshape(bbox_inside_weights, (1, height, width, A * 4),
                                     name='rpn_bbox_inside_weights')
# reshape the positive/negative sample weights
rpn_bbox_outside_weights = tf.reshape(bbox_outside_weights, (1, height, width, A * 4),
                                      name='rpn_bbox_outside_weights')
rpn_labels = tf.cast(rpn_labels, dtype=tf.int32)

Output

tf.Tensor(
[1443 1488 1569 1623 1641 1749 1794 1836 1911 1923 2085 2100 2166 2188
 2244 2265 2319 2331 2337 2433 2454 2517 2539 2553 2613 2634 2661 2664
 2796 2838 2854 2863 2872 2877 2881 2983 2988 3196 3201 3205 3214 3223
 3258 3262 3277 3333 3334 3462 3507 3538 3547 3556 3558 3561 3565 3627
 3646 3699 3705 3723 3801 3846 3877 3880 3889 3898 3907 3928 3940 4035
 4048 4065 4137 4170 4188 4210 4219 4222 4231 4240 4249 4266 4299 4309
 4315 4341 4500 4509 4533 4542 4552 4561 4564 4572 4573 4582 4590 4591
 4629 4671 4672 4707 4839 4857 4863 4894 4903 4906 4915 4924 4933 4978
 5035 5236 5245 5248 5257 5266 5275 5337 5347 5349 5352 5412 5578 5586
 5587 5590 5599 5608 5617 5661 5679 5695 5733 5754 5766 5772 5784 5862
 5919 5932 5941 5950 5959 5997 6003 6100 6126 6274 6283 6285 6292 6300
 6301 6306 6318 6325 6351 6361 6369 6373 6408 6550 6559 6609 6616 6625
 6634 6643 6648 6682 6684 6697 6703 6745 6754 6775 6918 6924 6933 6936
 6958 6963 6966 6967 6976 6985 6988 7020 7047 7074 7269 7275 7300 7305
 7309 7318 7327 7330 7338 7375 7390 7392 7408 7429 7467 7633 7642 7651
 7656 7660 7669 7672 7698 7710 7732 7762 7767 7791 7806 7920 7953 7984
 7993 8002 8010 8011 8031 8052 8079 8115 8148 8157 8283 8304 8326 8335
 8343 8344 8353 8385 8397 8403 8670 8682 8727 8784 8817 9009 9063 9066
 9165 9297 9405 9459], shape=(256,), dtype=int64)
tf.Tensor([1080 1089 1098 ... 9495 9504 9513], shape=(2870,), dtype=int64)
tf.Tensor(
[2854 2863 2872 2881 3196 3205 3214 3223 3538 3547 3556 3565 3877 3880
 3889 3898 3907 4210 4219 4222 4231 4240 4249 4552 4561 4564 4573 4582
 4591 4894 4903 4906 4915 4924 4933 5236 5245 5248 5257 5266 5275 5578
 5587 5590 5599 5608 5617 5932 5941 5950 5959 6274 6283 6292 6301 6616
 6625 6634 6643 6958 6967 6976 6985 7300 7309 7318 7327 7642 7651 7660
 7669 7984 7993 8002 8011 8326 8335 8344 8353], shape=(79,), dtype=int64)
tf.Tensor(
[1443 1488 1569 1623 1641 1749 1794 1836 1911 1923 2085 2100 2166 2188
 2244 2265 2319 2331 2337 2433 2454 2517 2539 2553 2613 2634 2661 2664
 2796 2838 2854 2863 2872 2877 2881 2983 2988 3196 3201 3205 3214 3223
 3258 3262 3277 3333 3334 3462 3507 3538 3547 3556 3558 3561 3565 3627
 3646 3699 3705 3723 3801 3846 3877 3880 3889 3898 3907 3928 3940 4035
 4048 4065 4137 4170 4188 4210 4219 4222 4231 4240 4249 4266 4299 4309
 4315 4341 4500 4509 4533 4542 4552 4561 4564 4572 4573 4582 4590 4591
 4629 4671 4672 4707 4839 4857 4863 4894 4903 4906 4915 4924 4933 4978
 5035 5236 5245 5248 5257 5266 5275 5337 5347 5349 5352 5412 5578 5586
 5587 5590 5599 5608 5617 5661 5679 5695 5733 5754 5766 5772 5784 5862
 5919 5932 5941 5950 5959 5997 6003 6100 6126 6274 6283 6285 6292 6300
 6301 6306 6318 6325 6351 6361 6369 6373 6408 6550 6559 6609 6616 6625
 6634 6643 6648 6682 6684 6697 6703 6745 6754 6775 6918 6924 6933 6936
 6958 6963 6966 6967 6976 6985 6988 7020 7047 7074 7269 7275 7300 7305
 7309 7318 7327 7330 7338 7375 7390 7392 7408 7429 7467 7633 7642 7651
 7656 7660 7669 7672 7698 7710 7732 7762 7767 7791 7806 7920 7953 7984
 7993 8002 8010 8011 8031 8052 8079 8115 8148 8157 8283 8304 8326 8335
 8343 8344 8353 8385 8397 8403 8670 8682 8727 8784 8817 9009 9063 9066
 9165 9297 9405 9459], shape=(256,), dtype=int64)
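Before moving on, it is worth seeing where the two weight tensors end up. The training code is not part of this walkthrough, but in the standard Faster R-CNN formulation (a sketch only; the function name is mine) bbox_inside_weights masks the smooth L1 loss down to the positive anchors, while bbox_outside_weights carries the 1/256 normalization computed above:

import tensorflow as tf

def rpn_smooth_l1_loss(bbox_pred, bbox_targets, inside_w, outside_w, sigma=3.0):
    sigma_2 = sigma ** 2
    # inside weights zero out every coordinate that does not belong to a positive
    diff = inside_w * (bbox_pred - bbox_targets)
    abs_diff = tf.abs(diff)
    # quadratic below 1/sigma^2, linear above
    flag = tf.cast(abs_diff < 1.0 / sigma_2, tf.float32)
    loss = flag * 0.5 * sigma_2 * tf.square(diff) + \
           (1.0 - flag) * (abs_diff - 0.5 / sigma_2)
    # outside weights already hold the 1/num_examples normalization
    return tf.reduce_sum(outside_w * loss)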

Now let's look at the unmap function

def unmap(data, count, inds, fill, type):
    """ Map data back to the full original anchor set of size count. """
    if type == 'labels':
        # build a zero vector with one entry per original anchor (10944)
        ret = tf.zeros((count,), dtype=tf.float32, name="unmap_" + type)
        # turn the 0s into the fill value (-1)
        ret += fill
        # expand the inside-image anchor indices by one dimension
        inds_expand = tf.expand_dims(inds, axis=1)
        # scatter the label data into ret at the inside-image anchor positions
        return tf.tensor_scatter_nd_update(ret, inds_expand, data)
    else:
        # build a zero matrix with one 4-d row per original anchor (10944)
        ret = tf.zeros(tf.concat([[count, ], tf.shape(data)[1:]], axis=0), dtype=tf.float32, name="unmap_" + type)
        ret += fill
        # expand the inside-image anchor indices by one dimension
        inds_expand = tf.expand_dims(inds, axis=1)
        # scatter the coordinate data into ret at the inside-image anchor positions
        return tf.tensor_scatter_nd_update(ret, inds_expand, data)

Continuing the region_proposal_network function

# ROI sampling; bbox_targets [dx, dy, dw, dh] are then computed from the rois [0, x1, y1, x2, y2]
rpn_rois = blob
rpn_scores = scores
use_gt = False
# total number of foreground + background rois per image
train_batch_size = 256
# foreground fraction
fg_fraction = 0.5
classes = ['__background__', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'aeroplane',
           'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train', 'bottle', 'chair',
           'diningtable', 'pottedplant', 'sofa', 'tvmonitor', 'person']
num_classes = len(classes)
# wait until rpn_labels is fully computed before executing what follows
with tf.control_dependencies([rpn_labels]):
    # Proposal ROIs (0, x1, y1, x2, y2) coming from RPN
    # (i.e., rpn.proposal_layer.ProposalLayer), or any other source
    # the proposals (candidate regions) and scores that survived non-maximum suppression
    all_rois = rpn_rois
    all_scores = rpn_scores

    # Include ground-truth boxes in the set of candidate rois
    if use_gt:
        zeros = tf.zeros((tf.shape(gt_boxes)[0], 1), dtype=tf.float32)
        all_rois = tf.concat([all_rois, tf.concat([zeros, gt_boxes[:, :-1]], axis=1)], axis=0)
        # not sure if it a wise appending, but anyway i am not using it
        all_scores = tf.concat([all_scores, zeros], axis=0)

    num_images = 1
    rois_per_image = int(train_batch_size / num_images)
    fg_rois_per_image = int(fg_fraction * rois_per_image)

    # Sample rois with classification labels and bounding box regression
    # targets
    labels, rois, roi_scores, bbox_targets, bbox_inside_weights = sample_rois(
        all_rois=all_rois,
        all_scores=all_scores,
        gt_boxes=gt_boxes,
        fg_rois_per_image=fg_rois_per_image,
        rois_per_image=rois_per_image,
        num_classes=num_classes)

    rois = tf.reshape(rois, shape=(-1, 5))
    roi_scores = tf.reshape(roi_scores, shape=(-1,))
    labels = tf.reshape(labels, shape=(-1,))
    bbox_targets = tf.reshape(bbox_targets, shape=(-1, num_classes * 4))
    bbox_inside_weights = tf.reshape(bbox_inside_weights, shape=(-1, num_classes * 4))
    bbox_outside_weights = tf.cast(bbox_inside_weights > 0, dtype=tf.float32)

Here we need to take a close look at the sample_rois function

def sample_rois(all_rois, all_scores, gt_boxes, fg_rois_per_image, rois_per_image, num_classes):
    """
    Generate a random sample of ROIs comprising foreground and background examples.
    """
    # overlaps: (rois x gt_boxes)
    # compute the IoU between the NMS-filtered boxes and the gt boxes
    overlaps = bbox_overlaps_tf(
        tf.cast(all_rois[:, 1:5], dtype=tf.float32),
        tf.cast(gt_boxes[:, :4], dtype=tf.float32))
    print('overlaps', overlaps)
    # for each NMS box, the index of the gt box it overlaps most
    gt_assignment = tf.argmax(overlaps, axis=1)
    print('gt_assignment', gt_assignment)
    # the maximum IoU of each NMS box
    max_overlaps = backend.max(overlaps, axis=1)
    print('max_overlaps', max_overlaps)
    # the class label of the gt box each NMS box overlaps most
    labels = tf.reshape(tf.gather(gt_boxes, gt_assignment)[:, 4], shape=(-1,))
    print(len(tf.where(labels > 1)))

Output

overlaps tf.Tensor(
[[0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 ...
 [0.07142365 0.00525129]
 [0.2026439  0.15127471]
 [0.         0.        ]], shape=(1526, 2), dtype=float32)
gt_assignment tf.Tensor([0 0 0 ... 0 0 0], shape=(1526,), dtype=int64)
max_overlaps tf.Tensor([0.         0.         0.         ... 0.07142365 0.2026439  0.        ], shape=(1526,), dtype=float32)
429

From the output we can see that 429 of the 1526 boxes are matched to a gt box of class 2.
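bbox_overlaps_tf itself is not shown in this walkthrough. A minimal sketch of what such a pairwise IoU function looks like, assuming the same +1 pixel convention as bbox_transform_tf above:

import tensorflow as tf

def bbox_overlaps_tf(boxes, gt_boxes):
    # boxes: (N, 4), gt_boxes: (K, 4), both [x1, y1, x2, y2]; returns an (N, K) IoU matrix
    b = tf.expand_dims(boxes, axis=1)     # (N, 1, 4)
    g = tf.expand_dims(gt_boxes, axis=0)  # (1, K, 4)
    # intersection width/height, clamped at 0 for non-overlapping pairs
    iw = tf.minimum(b[..., 2], g[..., 2]) - tf.maximum(b[..., 0], g[..., 0]) + 1.0
    ih = tf.minimum(b[..., 3], g[..., 3]) - tf.maximum(b[..., 1], g[..., 1]) + 1.0
    inter = tf.maximum(iw, 0.0) * tf.maximum(ih, 0.0)
    area_b = (b[..., 2] - b[..., 0] + 1.0) * (b[..., 3] - b[..., 1] + 1.0)
    area_g = (g[..., 2] - g[..., 0] + 1.0) * (g[..., 3] - g[..., 1] + 1.0)
    return inter / (area_b + area_g - inter)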

# NMS boxes with max IoU >= 0.5 become foreground; get their indices
fg_inds = tf.reshape(tf.where(max_overlaps >= 0.5), shape=(-1,))
print('fg_inds', fg_inds)
# NMS boxes with 0.1 <= max IoU < 0.5 become background; get their indices
bg_inds = tf.reshape(tf.where((max_overlaps < 0.5) &
                              (max_overlaps >= 0.1)),
                     shape=(-1,))
print('bg_inds', bg_inds)

Output

fg_inds tf.Tensor(
[  23   46  131  426  437  457  478  489  631  633  722  737  751  780
  811  813  863  888  937 1198 1247 1298 1319 1423], shape=(24,), dtype=int64)
bg_inds tf.Tensor(
[  11   14   20   28   31   36   38   41   42   44   66   68   72   77
   83   88   92  103  106  108  115  118  127  130  137  139  149  151
  153  155  163  164  173  179  182  185  186  191  192  193  196  197
  198  199  207  217  219  221  222  223  224  230  238  239  242  244
  245  246  249  258  261  266  267  268  269  271  272  276  284  285
  286  288  290  292  297  299  306  308  310  311  314  321  323  324
  327  331  332  335  337  339  340  341  343  344  346  354  357  360
  363  364  366  370  371  375  378  385  386  391  392  393  394  396
  397  399  402  406  407  408  410  411  413  415  416  417  419  425
  427  428  430  434  438  439  440  441  443  444  452  453  454  455
  462  467  468  470  471  472  473  474  476  484  485  491  493  494
  499  501  504  505  506  509  513  518  523  525  527  528  529  531
  533  534  537  540  541  542  544  545  549  550  553  554  557  559
  561  563  564  570  579  580  582  583  584  588  596  604  605  611
  615  618  619  622  623  625  626  632  638  641  645  646  651  659
  660  661  666  667  670  671  674  675  678  680  681  682  684  687
  695  698  701  716  718  720  721  723  729  736  741  745  759  766
  776  778  782  786  787  788  798  803  806  808  810  812  816  818
  824  829  833  834  838  839  842  847  849  852  854  859  860  861
  864  867  871  874  878  881  882  892  894  898  900  908  909  916
  921  926  929  930  931  936  939  942  959  966  967  969  971  972
  974  975  976  978  979  980  982  986  987  989  992  993  994  997
  998  999 1001 1002 1007 1008 1010 1011 1015 1017 1020 1022 1028 1033
 1037 1046 1048 1050 1052 1055 1057 1058 1063 1064 1067 1073 1075 1076
 1077 1079 1080 1082 1083 1084 1089 1090 1091 1094 1101 1102 1108 1113
 1115 1117 1118 1122 1129 1130 1131 1132 1133 1134 1135 1136 1138 1139
 1140 1141 1144 1146 1147 1148 1150 1151 1152 1154 1159 1160 1165 1167
 1169 1172 1174 1178 1180 1182 1184 1186 1187 1188 1189 1190 1192 1194
 1197 1200 1204 1205 1207 1209 1210 1214 1221 1224 1225 1227 1230 1233
 1234 1237 1242 1248 1252 1256 1261 1263 1264 1266 1267 1268 1272 1278
 1280 1281 1282 1283 1284 1285 1287 1289 1292 1293 1296 1297 1301 1306
 1308 1309 1310 1311 1312 1313 1315 1316 1317 1320 1321 1322 1324 1327
 1328 1330 1332 1333 1338 1346 1347 1349 1351 1354 1355 1356 1358 1359
 1360 1370 1371 1372 1373 1375 1376 1380 1382 1383 1386 1389 1390 1391
 1392 1395 1399 1400 1402 1403 1406 1407 1408 1409 1411 1413 1414 1415
 1419 1424 1425 1434 1436 1446 1448 1451 1455 1457 1459 1460 1461 1462
 1465 1468 1470 1472 1474 1476 1477 1478 1480 1483 1487 1488 1490 1494
 1498 1501 1504 1507 1510 1514 1515 1518 1522 1524], shape=(528,), dtype=int64)
Continuing the sample_rois function

# number of foreground boxes
fg_inds_nums = tf.shape(fg_inds)[0]
print('fg_inds_nums', fg_inds_nums)

# fg samples
# cap the per-image foreground count at 128, taking the smaller of the cap and the actual count
fg_rois_nums = tf.minimum(fg_rois_per_image, fg_inds_nums)
print('fg_rois_nums', fg_rois_nums)
# shuffle the foreground indices and keep the first fg_rois_nums
fg_inds_selected = tf.random.shuffle(fg_inds)[:fg_rois_nums]
print('fg_inds_selected', fg_inds_selected)
# the per-image total is 256; the background quota is whatever the foreground doesn't use
bg_rois_nums = rois_per_image - fg_rois_nums
print('bg_rois_nums', bg_rois_nums)
# shuffle the background indices and keep the first bg_rois_nums
bg_inds_selected = tf.random.shuffle(bg_inds)[:bg_rois_nums]
print('bg_inds_selected', bg_inds_selected)
# concatenate the shuffled foreground and background indices
fg_bg_sample_inds = tf.concat([fg_inds_selected, bg_inds_selected], axis=0)
print('fg_bg_sample_inds', fg_bg_sample_inds)
# actual number of sampled rois
samples_nums = tf.shape(fg_bg_sample_inds)[0]
print('samples_nums', samples_nums)
# shortfall: how many more rois are needed to fill the batch of rois_per_image
lack_nums = rois_per_image - samples_nums
print('lack_nums', lack_nums)
# tf.random.categorical draws num_samples indices according to the given logits;
# the higher a logit, the more likely its index is drawn (all-zero logits = uniform)
lack_samples_inds = tf.reshape(tf.random.categorical(
    logits=tf.expand_dims(tf.zeros(samples_nums, name="lack_samples_inds"), axis=0),
    num_samples=lack_nums), shape=(-1,))
print('lack_samples_inds', lack_samples_inds)
# map the drawn positions back to actual roi indices
lack_samples_inds = tf.gather(fg_bg_sample_inds, lack_samples_inds)
print('lack_samples_inds', lack_samples_inds)
# append the padding indices to the sampled foreground/background indices
total_samples_inds = tf.concat([fg_bg_sample_inds, lack_samples_inds], axis=0)
print('total_samples_inds', total_samples_inds)

Output

fg_inds_nums tf.Tensor(25, shape=(), dtype=int32)
fg_rois_nums tf.Tensor(25, shape=(), dtype=int32)
fg_inds_selected tf.Tensor(
[1468 1134  931 1473  386  810 1208  356  252 1488  192 1341  195 1170
  354 1124  786  414 1021 1502  932 1387 1417  923 1419], shape=(25,), dtype=int64)
bg_rois_nums tf.Tensor(231, shape=(), dtype=int32)
bg_inds_selected tf.Tensor(
[ 148  218 1054 1519   44  400  284 1183 1119  293 1002 1485 1159  651
  509  922  637  818 1469  916  138  905  179 1188  785  873 1106 1497
 1351 1500  558 1273 1498 1213  764  862  145  171 1096  694  556 1248
  116 1274  859  153 1080  990 1313  908 1057 1297 1058  166 1477 1462
 1402  294 1444 1152 1122  351 1481 1089 1167 1108  495 1066 1484 1236
   68 1463  176 1035  884  237 1140  725 1128  594 1489  235  736  788
 1247  525    0   79  244  856  696  155 1441  737  319 1154 1209 1015
  369  227 1131  529  540  249  920 1052  635  248 1499 1244 1451 1456
  853 1205  474  317  546  219  125  459  267  142  241 1516  863  150
  170 1508  130 1136  747  325 1179 1085  250 1214 1104  127  811 1301
 1268 1383  373  766 1482 1189  576  895 1203 1264 1172  672  739 1356
 1223 1091  303  960  482 1038  270 1153  262  147  224  310  887 1398
  508  487 1438 1030 1232  338 1149 1251  879 1090  567  970 1437 1102
  573 1053 1114 1425  189  220 1412 1505 1036 1397  939 1044 1198 1047
  311 1156 1518 1180  101 1354  332 1491 1343  995  964  329  829 1515
 1110 1266 1514  874  498   24 1150  805 1315 1384  112  121  893 1312
 1012 1320 1380 1155  918  593  247], shape=(231,), dtype=int64)
fg_bg_sample_inds tf.Tensor(
[1468 1134  931 1473  386  810 1208  356  252 1488  192 1341  195 1170
  354 1124  786  414 1021 1502  932 1387 1417  923 1419  148  218 1054
 1519   44  400  284 1183 1119  293 1002 1485 1159  651  509  922  637
  818 1469  916  138  905  179 1188  785  873 1106 1497 1351 1500  558
 1273 1498 1213  764  862  145  171 1096  694  556 1248  116 1274  859
  153 1080  990 1313  908 1057 1297 1058  166 1477 1462 1402  294 1444
 1152 1122  351 1481 1089 1167 1108  495 1066 1484 1236   68 1463  176
 1035  884  237 1140  725 1128  594 1489  235  736  788 1247  525    0
   79  244  856  696  155 1441  737  319 1154 1209 1015  369  227 1131
  529  540  249  920 1052  635  248 1499 1244 1451 1456  853 1205  474
  317  546  219  125  459  267  142  241 1516  863  150  170 1508  130
 1136  747  325 1179 1085  250 1214 1104  127  811 1301 1268 1383  373
  766 1482 1189  576  895 1203 1264 1172  672  739 1356 1223 1091  303
  960  482 1038  270 1153  262  147  224  310  887 1398  508  487 1438
 1030 1232  338 1149 1251  879 1090  567  970 1437 1102  573 1053 1114
 1425  189  220 1412 1505 1036 1397  939 1044 1198 1047  311 1156 1518
 1180  101 1354  332 1491 1343  995  964  329  829 1515 1110 1266 1514
  874  498   24 1150  805 1315 1384  112  121  893 1312 1012 1320 1380
 1155  918  593  247], shape=(256,), dtype=int64)
samples_nums tf.Tensor(256, shape=(), dtype=int32)
lack_nums tf.Tensor(0, shape=(), dtype=int32)
lack_samples_inds tf.Tensor([], shape=(0,), dtype=int64)
lack_samples_inds tf.Tensor([], shape=(0,), dtype=int64)
total_samples_inds tf.Tensor(
[1468 1134  931 1473  386  810 1208  356  252 1488  192 1341  195 1170
  354 1124  786  414 1021 1502  932 1387 1417  923 1419  148  218 1054
 1519   44  400  284 1183 1119  293 1002 1485 1159  651  509  922  637
  818 1469  916  138  905  179 1188  785  873 1106 1497 1351 1500  558
 1273 1498 1213  764  862  145  171 1096  694  556 1248  116 1274  859
  153 1080  990 1313  908 1057 1297 1058  166 1477 1462 1402  294 1444
 1152 1122  351 1481 1089 1167 1108  495 1066 1484 1236   68 1463  176
 1035  884  237 1140  725 1128  594 1489  235  736  788 1247  525    0
   79  244  856  696  155 1441  737  319 1154 1209 1015  369  227 1131
  529  540  249  920 1052  635  248 1499 1244 1451 1456  853 1205  474
  317  546  219  125  459  267  142  241 1516  863  150  170 1508  130
 1136  747  325 1179 1085  250 1214 1104  127  811 1301 1268 1383  373
  766 1482 1189  576  895 1203 1264 1172  672  739 1356 1223 1091  303
  960  482 1038  270 1153  262  147  224  310  887 1398  508  487 1438
 1030 1232  338 1149 1251  879 1090  567  970 1437 1102  573 1053 1114
 1425  189  220 1412 1505 1036 1397  939 1044 1198 1047  311 1156 1518
 1180  101 1354  332 1491 1343  995  964  329  829 1515 1110 1266 1514
  874  498   24 1150  805 1315 1384  112  121  893 1312 1012 1320 1380
 1155  918  593  247], shape=(256,), dtype=int64)

One thing to note: non-maximum suppression outputs a different number of boxes on every run, so the printed results will differ from run to run.
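As an aside, the padding trick above relies on tf.random.categorical sampling uniformly when every logit is zero. A two-line demo:

import tensorflow as tf

# with all-zero logits each of the 5 indices is equally likely; 3 draws with replacement
pad = tf.random.categorical(logits=tf.zeros((1, 5)), num_samples=3)
print(pad)  # e.g. [[4 0 2]] -- the exact values vary from run to run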

Continuing the sample_rois function

# build all-zero background labels with the same shape as the background indices
# (final_bg_inds comes from a step not shown above)
bg_labels = tf.zeros_like(final_bg_inds, dtype=tf.float32)
# expand the background indices by one dimension
labels_bg_inds_expand = tf.expand_dims(final_bg_inds, axis=1)
# scatter the background labels (0) back into the class labels
labels = tf.tensor_scatter_nd_update(labels, labels_bg_inds_expand, bg_labels)
print('labels', labels)
# gather the NMS proposals at the 256 sampled (foreground + background + padding) indices
rois = tf.gather(all_rois, total_samples_inds)
print('rois', rois)
# gather the NMS scores at the same 256 indices
roi_scores = tf.gather(all_scores, total_samples_inds)
print('roi_scores', roi_scores)
# gather, for each of the 256 sampled rois, the index of the gt box it overlaps most
gt_boxes_inds = tf.gather(gt_assignment, total_samples_inds)
print('gt_boxes_inds', gt_boxes_inds)
# build regression targets (with class labels) from the sampled proposals and their gt boxes
bbox_target_data = compute_targets(
    ex_rois=rois[:, 1:5],
    gt_rois=tf.gather(gt_boxes, gt_boxes_inds)[:, :4],
    labels=labels)
print('bbox_target_data', bbox_target_data)

Output

labels tf.Tensor(
[2. 1. 1. 1. 1. 2. 2. 2. 2. 1. 2. 2. 1. 1. 1. 2. 2. 1. 1. 2. 2. 1. 2. 2.
 1. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(256,), dtype=float32)
rois tf.Tensor(
[[  0.        87.73878   71.483185 343.38797  328.309   ]
 [  0.       119.943924   0.       472.45416  455.47726 ]
 [  0.        87.866516 152.07315  439.9245   499.      ]
 ...
 [  0.        24.023071 151.93892  151.91214  280.0304  ]
 [  0.       219.73932  192.06189  307.90833  368.234   ]
 [  0.       112.21364  248.26813  479.9464   440.0374  ]], shape=(256, 5), dtype=float32)
roi_scores tf.Tensor(
[[0.49930605]
 [0.49899176]
 [0.4998094 ]
 [0.50078213]
 [0.49922255]
 [0.5008811 ]
 [0.500855  ]
 [0.4999068 ]
 [0.5006265 ]
 [0.50016016]
 [0.49967363]
 [0.49936143]
 [0.5004159 ]
 [0.50048083]
 [0.500567  ]
 [0.49960896]
 [0.49965763]
 [0.49891794]
 [0.50058854]
 [0.49949652]
 [0.50067943]
 [0.5004672 ]
 [0.5008285 ]
 [0.5007335 ]
 [0.50038815]
 [0.5006373 ]
 [0.50013715]
 [0.50041246]
 [0.49971274]
 [0.5001378 ]
 [0.5001666 ]
 [0.499923  ]
 [0.500278  ]
 [0.5001397 ]
 [0.50009996]
 [0.4999674 ]
 [0.49966696]
 [0.5004775 ]
 [0.49977943]
 [0.5005407 ]
 [0.49933553]
 [0.49965364]
 [0.49919152]
 [0.5003087 ]
 [0.49969524]
 [0.50025797]
 [0.49954724]
 [0.5000017 ]
 [0.50023586]
 [0.50079584]
 [0.50040305]
 [0.5003251 ]
 [0.4998315 ]
 [0.5000335 ]
 [0.49963558]
 [0.49961445]
 [0.5003741 ]
 [0.50002867]
 [0.5000135 ]
 [0.49909943]
 [0.50092566]
 [0.49955425]
 [0.49969003]
 [0.50016105]
 [0.49999413]
 [0.49998608]
 [0.4996309 ]
 [0.5006923 ]
 [0.4992457 ]
 [0.50010973]
 [0.5003303 ]
 [0.49860814]
 [0.49921677]
 [0.4994349 ]
 [0.49987713]
 [0.49983975]
 [0.4996965 ]
 [0.4998196 ]
 [0.50081396]
 [0.5003692 ]
 [0.49965516]
 [0.5002867 ]
 [0.49994773]
 [0.4995782 ]
 [0.50018597]
 [0.5003698 ]
 [0.49960107]
 [0.5004119 ]
 [0.49988782]
 [0.50052005]
 [0.49963364]
 [0.50028425]
 [0.49980932]
 [0.500316  ]
 [0.49917495]
 [0.49966818]
 [0.50020355]
 [0.49922436]
 [0.4998039 ]
 [0.49945158]
 [0.50058955]
 [0.50046855]
 [0.5009947 ]
 [0.49988025]
 [0.49995822]
 [0.49991602]
 [0.49972185]
 [0.4998427 ]
 [0.49969718]
 [0.5003703 ]
 [0.49982855]
 [0.50040054]
 [0.50041145]
 [0.49982548]
 [0.50002795]
 [0.50030303]
 [0.5009165 ]
 [0.50074977]
 [0.4996151 ]
 [0.4999139 ]
 [0.49976528]
 [0.4998251 ]
 [0.4995345 ]
 [0.50009197]
 [0.4998482 ]
 [0.5003186 ]
 [0.49942216]
 [0.50086457]
 [0.49990255]
 [0.50043935]
 [0.5002069 ]
 [0.50029546]
 [0.49971247]
 [0.49996173]
 [0.5000769 ]
 [0.49963284]
 [0.5003248 ]
 [0.49970955]
 [0.49971664]
 [0.50030905]
 [0.5008899 ]
 [0.50003475]
 [0.50043875]
 [0.49984676]
 [0.5003003 ]
 [0.5001555 ]
 [0.50100607]
 [0.49972388]
 [0.49985737]
 [0.4998918 ]
 [0.50054806]
 [0.4996141 ]
 [0.49911308]
 [0.4998709 ]
 [0.49965635]
 [0.500453  ]
 [0.5000294 ]
 [0.49936455]
 [0.4995669 ]
 [0.5007801 ]
 [0.49937472]
 [0.50025356]
 [0.50062895]
 [0.50103045]
 [0.5002717 ]
 [0.5002477 ]
 [0.5000469 ]
 [0.50033987]
 [0.50001603]
 [0.49967104]
 [0.5008208 ]
 [0.50023985]
 [0.5009931 ]
 [0.49981537]
 [0.4995633 ]
 [0.49976498]
 [0.4995851 ]
 [0.49983037]
 [0.49980605]
 [0.4999846 ]
 [0.50032395]
 [0.4992842 ]
 [0.4999961 ]
 [0.49951896]
 [0.49983808]
 [0.500997  ]
 [0.50009537]
 [0.49967426]
 [0.49984682]
 [0.5002668 ]
 [0.49987146]
 [0.49924347]
 [0.5004875 ]
 [0.500055  ]
 [0.49986565]
 [0.5001018 ]
 [0.5000276 ]
 [0.4996256 ]
 [0.4995221 ]
 [0.5001933 ]
 [0.5002657 ]
 [0.5012568 ]
 [0.50037634]
 [0.4997889 ]
 [0.4995863 ]
 [0.4997971 ]
 [0.49964944]
 [0.4999195 ]
 [0.5000146 ]
 [0.5002807 ]
 [0.500206  ]
 [0.49956107]
 [0.5000731 ]
 [0.4991624 ]
 [0.5004191 ]
 [0.4999311 ]
 [0.49965027]
 [0.49986038]
 [0.49983898]
 [0.5000175 ]
 [0.4999176 ]
 [0.4998784 ]
 [0.50021714]
 [0.49988425]
 [0.5002642 ]
 [0.5001255 ]
 [0.50034624]
 [0.49994233]
 [0.49979162]
 [0.5005239 ]
 [0.4998375 ]
 [0.49986005]
 [0.49955866]
 [0.49969354]
 [0.5007337 ]
 [0.4996154 ]
 [0.49927685]
 [0.50034404]
 [0.5005275 ]
 [0.49940822]
 [0.49961078]
 [0.4992394 ]
 [0.49975425]
 [0.49922103]
 [0.49998415]
 [0.50038314]
 [0.49955294]
 [0.5005984 ]
 [0.5000376 ]
 [0.5001361 ]
 [0.49999878]
 [0.49961486]
 [0.49979705]
 [0.5002752 ]
 [0.49957287]
 [0.5000685 ]], shape=(256, 1), dtype=float32)
gt_boxes_inds tf.Tensor(
[1 0 0 0 0 1 1 1 1 0 1 1 0 0 0 1 1 0 0 1 1 0 1 1 0 1 0 0 1 0 1 1 0 0 0 0 1
 0 1 0 1 1 0 1 0 1 0 0 0 1 1 0 1 1 1 0 1 0 0 1 0 1 1 0 0 0 1 0 0 0 1 0 1 0
 1 1 1 0 0 1 1 1 0 1 0 1 1 1 1 1 1 1 0 1 0 1 0 0 1 1 0 0 0 1 1 1 1 1 1 1 0
 1 1 1 0 1 1 0 0 1 0 1 0 0 1 1 0 0 1 0 1 1 1 0 0 1 1 1 1 1 0 0 0 0 1 1 0 1
 1 1 1 1 0 1 0 0 0 1 0 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 1 1 1 0 0 1 1 1
 0 0 1 1 1 1 0 0 1 1 1 0 1 1 0 1 0 1 1 1 0 1 1 0 1 0 1 0 0 1 1 1 1 1 0 1 1
 0 1 1 0 1 1 1 0 1 1 1 1 1 1 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 1 1 1 1 0], shape=(256,), dtype=int64)
bbox_target_data tf.Tensor(
[[ 2.         -1.5203252   1.1220162  -1.4743905   0.42443678]
 [ 2.          1.5926099  -0.7512477  -1.4726366   0.4250237 ]
 [ 2.          0.51282847  0.34382617  0.3915599  -1.1635339 ]
 ...
 [ 0.         -1.2433872  -2.0211256   0.16303922  5.304551  ]
 [ 0.         -1.243926   11.182966    0.1625644   5.3030634 ]
 [ 0.         -5.7877192   4.03606     4.5598125   2.3695912 ]], shape=(256, 5), dtype=float32)

Now let's look at the compute_targets function

def compute_targets(ex_rois, gt_rois, labels):
    """
    Compute bounding-box regression targets (with class labels) for an image.
    """
    # encode the sampled NMS boxes against their gt boxes
    targets = bbox_transform_tf(ex_rois, gt_rois)
    train_bbox_normalize_targets_precomputed = True
    if train_bbox_normalize_targets_precomputed:
        # Optionally normalize targets by a precomputed mean and stdev
        targets = (targets - (0.0, 0.0, 0.0, 0.0)) / (0.1, 0.1, 0.2, 0.2)
    labels_expand = tf.expand_dims(labels, axis=1)
    targets_add_labels = tf.concat([labels_expand, targets], axis=1)
    return targets_add_labels
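Note the division by the precomputed stdev (0.1, 0.1, 0.2, 0.2): at inference time the network's predicted deltas have to be multiplied by the same factors before being decoded into boxes. A one-line sketch (bbox_pred here is only a stand-in for the regression head's raw output):

import tensorflow as tf

bbox_pred = tf.zeros((256, 4))             # stand-in for one class's raw deltas
deltas = bbox_pred * (0.1, 0.1, 0.2, 0.2)  # undo the target normalization before decoding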

Continuing the sample_rois function

bbox_targets, bbox_inside_weights = get_bbox_regression_labels(bbox_target_data=bbox_target_data,
                                                               num_classes=num_classes)

return labels, rois, roi_scores, bbox_targets, bbox_inside_weights

The get_bbox_regression_labels function is shown below

def get_bbox_regression_labels(bbox_target_data, num_classes):
    """Bounding-box regression targets (bbox_target_data) are stored in a
    compact N x (class, tx, ty, tw, th) form. This function expands those
    targets into the 4-of-4*K representation used by the network (i.e. only
    one class has non-zero targets).

    Returns:
        bbox_target (ndarray): N x 4K blob of regression targets
        bbox_inside_weights (ndarray): N x 4K blob of loss weights
    """
    # class label of every target box
    clss = bbox_target_data[:, 0]
    print('clss', clss)
    # all-zero regression targets of shape (256, 4*21): 256 sampled rois, 21 classes
    bbox_targets = tf.zeros((tf.shape(clss)[0], 4 * num_classes), dtype=tf.float32, name="regression_bbox_targets")
    print('bbox_targets', bbox_targets)
    # all-zero regression weights of the same shape
    bbox_inside_weights = tf.zeros_like(bbox_targets, dtype=tf.float32)
    print('bbox_inside_weights', bbox_inside_weights)
    # row indices of the non-background rois
    inds = tf.cast(tf.reshape(tf.where(clss > 0), shape=(-1,)), dtype=tf.int32)
    print('inds', inds)
    # class labels of the non-background rois
    cols = tf.cast(tf.gather(clss, inds), dtype=tf.int32)
    print('cols', cols)
    # each class label times 4 gives the starting column of that class's 4 slots
    starts = tf.expand_dims(cols * 4, axis=1)
    print('starts', starts)
    starts_1 = starts + 1
    starts_2 = starts + 2
    starts_3 = starts + 3
    col_inds = tf.reshape(tf.concat([starts, starts_1, starts_2, starts_3], axis=1), (-1,))
    print('col_inds', col_inds)
    # repeat each row index 4 times to pair with the 4 column indices
    row_inds = tf.reshape(tf.tile(tf.expand_dims(inds, axis=1), [1, 4]), (-1,))
    print('row_inds', row_inds)
    # pair up the row and column indices
    row_col_inds = tf.concat([tf.expand_dims(row_inds, 1), tf.expand_dims(col_inds, 1)], axis=1)
    print('row_col_inds', row_col_inds)
    # the 4 target coordinates of each non-background roi, flattened
    updates_target_data = tf.reshape(tf.gather(bbox_target_data[:, 1:], inds), (-1,))
    print('updates_target_data', updates_target_data)
    # the corresponding weights are all 1
    updates_inside_weight = tf.reshape(tf.tile((1.0, 1.0, 1.0, 1.0), [tf.shape(inds)[0]]), (-1,))
    print('updates_inside_weight', updates_inside_weight)
    # scatter the target coordinates into bbox_targets
    bbox_targets = tf.tensor_scatter_nd_update(bbox_targets, row_col_inds, updates_target_data)
    np.set_printoptions(threshold=np.inf)
    print('bbox_targets', tf.gather(bbox_targets, tf.unique(tf.where(bbox_targets > 0)[:, 0])[0]))
    # scatter the weights into bbox_inside_weights
    bbox_inside_weights = tf.tensor_scatter_nd_update(bbox_inside_weights, row_col_inds, updates_inside_weight)
    print('bbox_inside_weights', tf.gather(bbox_inside_weights, tf.unique(tf.where(bbox_inside_weights > 0)[:, 0])[0]))

    return bbox_targets, bbox_inside_weights

Output

clss tf.Tensor(
[2. 2. 2. 1. 2. 1. 2. 1. 1. 1. 2. 2. 2. 1. 2. 2. 1. 2. 1. 2. 2. 2. 2. 1.
 1. 2. 2. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(256,), dtype=float32)
bbox_targets tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(256, 84), dtype=float32)
bbox_inside_weights tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(256, 84), dtype=float32)
inds tf.Tensor(
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27], shape=(28,), dtype=int32)
cols tf.Tensor([2 2 2 1 2 1 2 1 1 1 2 2 2 1 2 2 1 2 1 2 2 2 2 1 1 2 2 1], shape=(28,), dtype=int32)
starts tf.Tensor(
[[8]
 [8]
 [8]
 [4]
 [8]
 [4]
 [8]
 [4]
 [4]
 [4]
 [8]
 [8]
 [8]
 [4]
 [8]
 [8]
 [4]
 [8]
 [4]
 [8]
 [8]
 [8]
 [8]
 [4]
 [4]
 [8]
 [8]
 [4]], shape=(28, 1), dtype=int32)
col_inds tf.Tensor(
[ 8  9 10 11  8  9 10 11  8  9 10 11  4  5  6  7  8  9 10 11  4  5  6  7
  8  9 10 11  4  5  6  7  4  5  6  7  4  5  6  7  8  9 10 11  8  9 10 11
  8  9 10 11  4  5  6  7  8  9 10 11  8  9 10 11  4  5  6  7  8  9 10 11
  4  5  6  7  8  9 10 11  8  9 10 11  8  9 10 11  8  9 10 11  4  5  6  7
  4  5  6  7  8  9 10 11  8  9 10 11  4  5  6  7], shape=(112,), dtype=int32)
row_inds tf.Tensor(
[ 0  0  0  0  1  1  1  1  2  2  2  2  3  3  3  3  4  4  4  4  5  5  5  5
  6  6  6  6  7  7  7  7  8  8  8  8  9  9  9  9 10 10 10 10 11 11 11 11
 12 12 12 12 13 13 13 13 14 14 14 14 15 15 15 15 16 16 16 16 17 17 17 17
 18 18 18 18 19 19 19 19 20 20 20 20 21 21 21 21 22 22 22 22 23 23 23 23
 24 24 24 24 25 25 25 25 26 26 26 26 27 27 27 27], shape=(112,), dtype=int32)
row_col_inds tf.Tensor(
[[ 0  8]
 [ 0  9]
 [ 0 10]
 [ 0 11]
 [ 1  8]
 [ 1  9]
 [ 1 10]
 [ 1 11]
 [ 2  8]
 [ 2  9]
 [ 2 10]
 [ 2 11]
 [ 3  4]
 [ 3  5]
 [ 3  6]
 [ 3  7]
 [ 4  8]
 [ 4  9]
 [ 4 10]
 [ 4 11]
 [ 5  4]
 [ 5  5]
 [ 5  6]
 [ 5  7]
 [ 6  8]
 [ 6  9]
 [ 6 10]
 [ 6 11]
 [ 7  4]
 [ 7  5]
 [ 7  6]
 [ 7  7]
 [ 8  4]
 [ 8  5]
 [ 8  6]
 [ 8  7]
 [ 9  4]
 [ 9  5]
 [ 9  6]
 [ 9  7]
 [10  8]
 [10  9]
 [10 10]
 [10 11]
 [11  8]
 [11  9]
 [11 10]
 [11 11]
 [12  8]
 [12  9]
 [12 10]
 [12 11]
 [13  4]
 [13  5]
 [13  6]
 [13  7]
 [14  8]
 [14  9]
 [14 10]
 [14 11]
 [15  8]
 [15  9]
 [15 10]
 [15 11]
 [16  4]
 [16  5]
 [16  6]
 [16  7]
 [17  8]
 [17  9]
 [17 10]
 [17 11]
 [18  4]
 [18  5]
 [18  6]
 [18  7]
 [19  8]
 [19  9]
 [19 10]
 [19 11]
 [20  8]
 [20  9]
 [20 10]
 [20 11]
 [21  8]
 [21  9]
 [21 10]
 [21 11]
 [22  8]
 [22  9]
 [22 10]
 [22 11]
 [23  4]
 [23  5]
 [23  6]
 [23  7]
 [24  4]
 [24  5]
 [24  6]
 [24  7]
 [25  8]
 [25  9]
 [25 10]
 [25 11]
 [26  8]
 [26  9]
 [26 10]
 [26 11]
 [27  4]
 [27  5]
 [27  6]
 [27  7]], shape=(112, 2), dtype=int32)
updates_target_data tf.Tensor(
[ 0.5039273   1.9543941   0.39690098 -0.808535   -0.9039384  -1.3870759
 -1.4696271   0.42652082 -0.28305605 -0.1425246  -1.4708441   0.42647812
 -1.5528293  -0.81134266  1.1135973   2.0942068   0.9649968  -1.386991
 -1.4698946   0.42687    -0.9716667   1.6943251  -1.5394237   1.5907156
 -2.2117233   0.8024186   0.39624962 -1.1595803  -2.4491363   0.27737466
  1.1158336   1.0569729  -1.3747127  -0.32095075 -0.75373423  1.6534086
 -0.7521222   0.8893123  -0.7552303   0.36309978  0.33687025  1.1044904
 -1.4701374   0.4267589   1.4062967   0.34899122  0.3959936  -1.158884
 -0.402098    0.34905997  0.39632288 -1.1589644   0.11842789  0.03336683
 -0.34195063  1.5918659  -1.3035506  -1.0094306   0.39754125 -1.1612608
  2.205167    1.1041038  -1.4706837   0.42680323 -1.5475154   1.2948655
  1.116524    0.5621466   1.585693   -0.14235474 -1.4710339   0.42653617
 -0.63558406  0.27993834  1.1161374   1.0542582  -0.39966    -2.3699164
  0.39738548 -1.1610215   2.3146038  -1.0092112   0.39724234 -1.1606102
  0.50494    -1.009559    0.3974295  -1.1610612  -0.9082577   1.7271762
 -1.4701273   0.42750752  0.91846555  1.2800353   0.38419884  1.5924516
  0.7224141   0.9540003   1.1167352   0.28379184 -2.154466    1.1046239
 -1.4700354   0.4270566  -1.304847    1.954505    0.39715976 -0.8084551
  2.356012    0.6044802   1.471402    0.6980772 ], shape=(112,), dtype=float32)
updates_inside_weight tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(112,), dtype=float32)
bbox_targets tf.Tensor(
[[ 0.          0.          0.          0.         -1.550683    0.9540003
   1.1177553   0.28379184  0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          2.7304864   0.74784005
   1.7196715   0.53228384  0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.         -0.9716667   2.1129775
  -1.5394237   1.5833551   0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.         -1.0994201  -0.814709
   1.1128297   2.0971014   0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   ...
   0.          0.          0.          0.          0.          0.        ]
 ...
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.9561669   0.4879443  -1.4740443   0.431881
   0.          0.          0.          0.          0.          0.
   ...
   0.          0.          0.          0.          0.          0.        ]], shape=(26, 84), dtype=float32)
bbox_inside_weights tf.Tensor(
[[0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 ...
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(26, 84), dtype=float32)

Let us briefly retrace the pipeline so far: an image --> Vgg16 --> feature map --> Anchor mechanism --> the coordinates of all Bounding boxes.

An image --> Vgg16 --> feature map --> two 1*1 convolution kernels extract, for every feature-map pixel, classification scores and position information (offset values) --> combine the Bounding box coordinates above with the position information (offsets) --> candidate regions --> clip to the boundary of the original image --> non-maximum suppression --> post-NMS Boxes --> compute IOU against the annotated regions --> obtain each post-NMS Box's maximum IOU (taking whichever annotated region gives the larger IOU), together with the class annotation of each annotated region's maximum IOU (the largest IOU the annotated region attains over the post-NMS Boxes) --> post-NMS Boxes with maximum IOU >= 0.5 become foreground boxes, and those with maximum IOU in [0.1, 0.5) become background boxes --> cap the foreground boxes at 128 --> cap the background boxes at 256 minus the number of foreground boxes --> draw the 256 class-annotated fore/background boxes from the post-NMS Boxes --> pair these fore/background boxes with the annotated boxes to generate class-annotated target boxes --> store the target-box information in one large all-classes matrix.
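
To make the foreground/background sampling step concrete, here is a minimal numpy sketch of the logic just described; max_overlaps is assumed to already hold each post-NMS box's largest IOU with any annotated region, and the function name is illustrative, not the actual helper in this code.

import numpy as np

def sample_rois(max_overlaps, rois_per_image=256, fg_fraction=0.5):
    # Boxes whose best IOU >= 0.5 are foreground candidates,
    # those with best IOU in [0.1, 0.5) are background candidates
    fg_inds = np.where(max_overlaps >= 0.5)[0]
    bg_inds = np.where((max_overlaps >= 0.1) & (max_overlaps < 0.5))[0]
    # Keep at most 128 foreground boxes (256 * 0.5)
    fg_num = min(int(rois_per_image * fg_fraction), fg_inds.size)
    fg_inds = np.random.choice(fg_inds, size=fg_num, replace=False)
    # The background fills the remaining 256 - fg_num slots
    bg_num = min(rois_per_image - fg_num, bg_inds.size)
    bg_inds = np.random.choice(bg_inds, size=bg_num, replace=False)
    return fg_inds, bg_inds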

An image --> Vgg16 --> feature map --> Anchor mechanism --> coordinates of all Bounding boxes --> clip to the boundary of the original image --> compute IOU against the annotated regions --> obtain each Bounding box's maximum IOU (taking whichever annotated region gives the larger IOU), as well as each annotated region's maximum IOU (the largest IOU the annotated region attains over the Bounding boxes).
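
The IOU computation in this chain can be written compactly in numpy. This is a sketch of the standard pairwise-overlap routine (the + 1.0 follows the pixel-inclusive convention of the original Faster RCNN code), not the exact helper used here.

import numpy as np

def bbox_overlaps(boxes, gt_boxes):
    # Pairwise intersection widths/heights, clipped at 0; boxes are
    # [x1, y1, x2, y2] arrays of shapes (N, 4) and (K, 4)
    iw = np.maximum(0.0, np.minimum(boxes[:, None, 2], gt_boxes[None, :, 2]) -
                         np.maximum(boxes[:, None, 0], gt_boxes[None, :, 0]) + 1.0)
    ih = np.maximum(0.0, np.minimum(boxes[:, None, 3], gt_boxes[None, :, 3]) -
                         np.maximum(boxes[:, None, 1], gt_boxes[None, :, 1]) + 1.0)
    inter = iw * ih
    area_b = (boxes[:, 2] - boxes[:, 0] + 1.0) * (boxes[:, 3] - boxes[:, 1] + 1.0)
    area_g = (gt_boxes[:, 2] - gt_boxes[:, 0] + 1.0) * (gt_boxes[:, 3] - gt_boxes[:, 1] + 1.0)
    overlaps = inter / (area_b[:, None] + area_g[None, :] - inter)
    # Each box's maximum IOU, and each annotated region's maximum IOU
    return overlaps.max(axis=1), overlaps.max(axis=0)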

An image --> Vgg16 --> feature map --> Anchor mechanism --> coordinates of all Bounding boxes --> clip to the boundary of the original image --> create a label vector of the same length, filled with -1 --> set the labels of the Bounding boxes whose IOU computed above is < 0.3 to 0 --> set the labels of the Bounding boxes that realise an annotated region's maximum IOU to 1 --> set the labels of the Bounding boxes whose IOU computed above is > 0.7 to 1 --> cap the total number of 1-labels at 128, resetting the surplus to 0 --> cap the total number of 0-labels at 256 (in effect, 256 minus the number of foreground labels), resetting the surplus to -1 --> restore the labels to the original number of Anchor Bounding boxes.
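
This anchor-labelling chain maps directly onto a few lines of numpy. The sketch below mirrors the steps one-to-one (labels: -1 ignored, 0 background, 1 foreground); gt_argmax_inds is assumed to index the anchors that realise each annotated region's maximum IOU, and all names are illustrative.

import numpy as np

def label_anchors(max_overlaps, gt_argmax_inds, num_fg=128, batch_size=256):
    labels = -np.ones(max_overlaps.shape[0], dtype=np.int32)  # start all -1
    labels[max_overlaps < 0.3] = 0          # low IOU -> background
    labels[gt_argmax_inds] = 1              # best anchor of each annotation
    labels[max_overlaps > 0.7] = 1          # high IOU -> foreground
    fg_inds = np.where(labels == 1)[0]
    if fg_inds.size > num_fg:               # cap foreground at 128
        labels[np.random.choice(fg_inds, fg_inds.size - num_fg, replace=False)] = 0
    bg_inds = np.where(labels == 0)[0]
    num_bg = batch_size - np.sum(labels == 1)
    if bg_inds.size > num_bg:               # keep 256 sampled anchors in total
        labels[np.random.choice(bg_inds, bg_inds.size - num_bg, replace=False)] = -1
    return labels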

An image --> Vgg16 --> feature map --> Anchor mechanism --> coordinates of all Bounding boxes --> clip to the boundary of the original image --> for each Bounding box, extract the coordinates of the annotated box with which it attains its maximum IOU --> from the coordinates of the Bounding boxes clipped to the original image and of their matched maximum-IOU annotated boxes, generate the target-box coordinates --> restore these target boxes to the original number of Anchor Bounding boxes.
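
Generating the target-box coordinates means encoding each (Bounding box, annotated box) pair as the (dx, dy, dw, dh) offsets described back in the RCNN section. A sketch of the standard encoding:

import numpy as np

def bbox_transform(boxes, gt_boxes):
    # Widths, heights and centres of the (clipped) Bounding boxes ...
    w = boxes[:, 2] - boxes[:, 0] + 1.0
    h = boxes[:, 3] - boxes[:, 1] + 1.0
    cx = boxes[:, 0] + 0.5 * w
    cy = boxes[:, 1] + 0.5 * h
    # ... and of their matched annotated boxes
    gw = gt_boxes[:, 2] - gt_boxes[:, 0] + 1.0
    gh = gt_boxes[:, 3] - gt_boxes[:, 1] + 1.0
    gcx = gt_boxes[:, 0] + 0.5 * gw
    gcy = gt_boxes[:, 1] + 0.5 * gh
    # Centre offsets are normalized by the box size; scales are log-ratios
    dx = (gcx - cx) / w
    dy = (gcy - cy) / h
    dw = np.log(gw / w)
    dh = np.log(gh / h)
    return np.stack([dx, dy, dw, dh], axis=1)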

An image --> Vgg16 --> feature map --> Anchor mechanism --> coordinates of all Bounding boxes --> clip to the boundary of the original image --> create a foreground-label array of the same length, filled with all-zero 4-dimensional vectors --> at the indices whose label is 1, update the entries of this foreground-label array to all-one 4-dimensional vectors --> restore these foreground labels to the original number of Anchor Bounding boxes.

An image --> Vgg16 --> feature map --> Anchor mechanism --> coordinates of all Bounding boxes --> clip to the boundary of the original image --> create a fore/background-weight array of the same length, filled with all-zero 4-dimensional vectors --> at the indices whose label is 1 or 0, update the entries to the corresponding weight, which for both foreground and background is 1 / (the fore+background total of 256) = 0.00390625 --> restore the fore/background weights to the original number of Anchor Bounding boxes. Both weight constructions are sketched below.
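
The two weight constructions from the last two steps fit in a few lines; this sketch assumes labels is the per-anchor label vector from the labelling step above, and the function name is illustrative.

import numpy as np

def make_rpn_bbox_weights(labels, batch_size=256):
    n = labels.shape[0]
    # Inside weights: the regression loss is only switched on for
    # foreground anchors (label == 1)
    inside = np.zeros((n, 4), dtype=np.float32)
    inside[labels == 1] = 1.0
    # Outside weights: every sampled anchor (label 0 or 1) gets an equal
    # share, 1 / 256 = 0.00390625
    outside = np.zeros((n, 4), dtype=np.float32)
    outside[(labels == 0) | (labels == 1)] = 1.0 / batch_size
    return inside, outside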

The region_proposal_network function continues:

# Put the labels, target boxes, foreground labels and fore/background
# weights (all restored to the original number of Anchor Bounding boxes)
# into a dictionary
anchor_targets['rpn_labels'] = rpn_labels
# [1,height,width, 9*4]
anchor_targets['rpn_bbox_targets'] = rpn_bbox_targets
# [1,height,width, 9*4]
anchor_targets['rpn_bbox_inside_weights'] = rpn_bbox_inside_weights
# [1,height,width, 9*4]
anchor_targets['rpn_bbox_outside_weights'] = rpn_bbox_outside_weights

# Put the post-NMS classification labels, target boxes, foreground labels
# and fore/background weights into a dictionary
# [256, 1]
proposal_targets['labels'] = labels
# [256, 4 * num_classes]
proposal_targets['bbox_targets'] = bbox_targets
# [256, 4 * num_classes]
proposal_targets['bbox_inside_weights'] = bbox_inside_weights
# [256, 4 * num_classes]
proposal_targets['bbox_outside_weights'] = bbox_outside_weights

# [1, h, w, 9*2]
predictions["rpn_cls_score"] = rpn_cls_score
# [1, h*9, w, 2]
predictions["rpn_cls_score_reshape"] = rpn_cls_score_reshape
predictions["rpn_cls_prob_reshape"] = rpn_cls_prob_reshape
# [1, h, w, 9*2]
predictions["rpn_cls_prob"] = rpn_cls_prob
# [h*w*9, 1]
predictions["rpn_cls_pred"] = rpn_cls_pred
# [1, h, w, 9*4]
predictions["rpn_bbox_pred"] = rpn_bbox_pred
# [256, 5]
predictions["rois"] = rois

return rois, roi_scores, labels, anchor_targets, proposal_targets, predictions

Then comes the prediction code for the non-training branch:

else:
    # Compare the predicted boxes with the anchors; after non-maximum
    # suppression, output the final target boxes [[idx, x1, y1, x2, y2],...]
    # and their scores
    scores = tf.reshape(rpn_cls_prob, (-1, 2))[:, 1]
    # scores = rpn_cls_prob[:, :, :, self.num_anchors:]
    # scores = tf.reshape(scores, shape=(-1,))
    rpn_bbox_pred = tf.reshape(rpn_bbox_pred, shape=(-1, 4))
    # Obtain the proposals (candidate regions) from the anchors and the offsets
    proposals = bbox_transform_inv_tf(anchors, rpn_bbox_pred)
    # Clip the boxes so that they all lie inside the image, i.e. all
    # coordinates >= 0 and less than the image width/height
    proposals = clip_boxes_tf(proposals, im_info)

    # Non-maximum suppression; outputs the indices of the surviving boxes
    indices = tf.image.non_max_suppression(boxes=proposals,
                                           scores=scores,
                                           max_output_size=2000,
                                           iou_threshold=0.7)
    # Gather the proposals selected by those indices
    boxes = tf.gather(proposals, indices)
    boxes = tf.cast(boxes, tf.float32)
    # Gather the corresponding scores
    scores = tf.gather(scores, indices)
    scores = tf.reshape(scores, shape=(-1, 1))

    # Only support single image as input
    batch_inds = tf.zeros((tf.shape(indices)[0], 1), dtype=tf.float32)
    # Prepend an all-zero batch-index column to the output proposals
    blob = tf.concat([batch_inds, boxes], 1)
    return blob, scores
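
Here bbox_transform_inv_tf is the inverse of the target encoding sketched earlier: it applies the predicted offsets (dx, dy, dw, dh) back onto the anchors. A numpy sketch of that inverse (the helper used in this code is the TensorFlow version):

import numpy as np

def bbox_transform_inv(anchors, deltas):
    w = anchors[:, 2] - anchors[:, 0] + 1.0
    h = anchors[:, 3] - anchors[:, 1] + 1.0
    cx = anchors[:, 0] + 0.5 * w
    cy = anchors[:, 1] + 0.5 * h
    # Shift the centre by (dx, dy); rescale width/height by exp(dw), exp(dh)
    pcx = deltas[:, 0] * w + cx
    pcy = deltas[:, 1] * h + cy
    pw = np.exp(deltas[:, 2]) * w
    ph = np.exp(deltas[:, 3]) * h
    # Back to corner coordinates [x1, y1, x2, y2]
    return np.stack([pcx - 0.5 * pw, pcy - 0.5 * ph,
                     pcx + 0.5 * pw, pcy + 0.5 * ph], axis=1)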

Having passed through the RPN network, the next step is to perform ROI Pooling on the feature map. In this implementation the pooling is approximated in two stages: tf.image.crop_and_resize first crops each roi out of the feature map and resizes it to 14*14, and a 2*2 max pool then brings it down to the final 7*7.

""" 裁剪層, 對卷積網絡層輸出的特徵, 根據rpn層輸出的roi進行裁剪, 且resize到統一的大小

        :return [bbox_nums, pre_pool_size, pre_pool_size, depth]
        """
# Batch indices of the feature maps; since we fed in a single image,
# these are all 0
batch_ids = tf.squeeze(tf.slice(rois, [0, 0], [-1, 1]), [1])
print('batch_ids', batch_ids)
# Height and width of the input image (the rois are in image coordinates,
# so dividing by these normalizes the boxes to [0, 1])
height = im_info[0]
width = im_info[1]
# Normalized coordinates of the post-NMS fore/background boxes
x1 = tf.expand_dims(rois[:, 1] / width, 1)
y1 = tf.expand_dims(rois[:, 2] / height, 1)
x2 = tf.expand_dims(rois[:, 3] / width, 1)
y2 = tf.expand_dims(rois[:, 4] / height, 1)
# Assemble into boxes in the [y1, x1, y2, x2] order that crop_and_resize expects
bboxes = tf.concat([y1, x1, y2, x2], axis=1)
print('bboxes', bboxes)
pool_size_after_rpn = 7
pre_pool_size = pool_size_after_rpn * 2
# [bbox_nums, pre_pool_size, pre_pool_size, depth]
# Crop each post-NMS fore/background box out of the feature map
crops = tf.image.crop_and_resize(image=feature_map,
                                 boxes=bboxes,
                                 box_indices=tf.cast(batch_ids, dtype=tf.int32),
                                 crop_size=[pre_pool_size, pre_pool_size])
print('crops', crops)
# Downsample the cropped regions with a 2x2 max pool
pool5 = layers.MaxPooling2D(pool_size=(2, 2), padding='SAME')(crops)
print('pool5', pool5)

Run results

batch_ids tf.Tensor(
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(256,), dtype=float32)
bboxes tf.Tensor(
[[0.         0.         0.6554428  0.7592091 ]
 [0.         0.21326488 0.68766814 0.50664294]
 [0.27119887 0.0402505  0.78392434 0.46705002]
 ...
 [0.         0.         0.998      0.44062048]
 [0.59157115 0.120248   0.8476926  0.3336019 ]
 [0.52808124 0.23404129 0.71991867 0.5404586 ]], shape=(256, 4), dtype=float32)
crops tf.Tensor(
[[[[2.57941138e-05 0.00000000e+00 0.00000000e+00 ... 2.36172258e-04
    7.93727158e-05 0.00000000e+00]
   ...
   [4.23501624e-05 0.00000000e+00 0.00000000e+00 ... 2.16595290e-04
    1.17535099e-04 0.00000000e+00]]

  ...

  [[4.70805884e-04 2.58358952e-04 1.03791654e-05 ... 6.89105305e-04
    9.89125983e-05 2.97268212e-04]
   ...
   [9.21951723e-04 1.31815369e-03 4.35110087e-06 ... 6.54781819e-04
    1.83739350e-04 2.79963948e-04]]]

 ...

 [[[4.61888034e-04 5.74681035e-04 0.00000000e+00 ... 4.63870616e-04
    1.68366809e-04 9.62797494e-05]
   ...
   [1.65601959e-03 1.67581788e-03 4.20510769e-04 ... 9.78714786e-04
    4.76391579e-04 7.12864101e-04]]]], shape=(256, 14, 14, 512), dtype=float32)
pool5 tf.Tensor(
[[[[8.50972065e-05 3.37529059e-06 1.78502582e-04 ... 2.42793089e-04
    1.18515294e-04 4.63427714e-05]
   ...
   [1.31847424e-04 7.12195979e-05 1.70186089e-04 ... 2.80556647e-04
    1.19131575e-04 4.67670543e-05]]

  ...

  [[5.56657731e-04 8.08325654e-04 3.33336327e-04 ... 6.89105305e-04
    9.89125983e-05 3.11441749e-04]
   ...
   [9.39709542e-04 1.36330479e-03 9.60405305e-05 ... 6.75685180e-04
    2.36389606e-04 3.15661775e-04]]]


 [[[1.56408452e-04 8.60906657e-05 2.45589996e-04 ... 2.99894164e-04
    1.58106530e-04 6.20882420e-05]
   [1.51653963e-04 1.05672603e-04 2.54451588e-04 ... 3.06682661e-04
    1.63726669e-04 5.45385556e-05]
   [1.52385197e-04 1.12267480e-04 2.70635268e-04 ... 3.14036501e-04
    1.70418323e-04 6.45255568e-05]
   ...
   [1.66914877e-04 1.28767817e-04 2.94387370e-04 ... 3.39201302e-04
    1.42578414e-04 5.29897225e-05]
   [1.57425020e-04 1.29032196e-04 3.02844070e-04 ... 3.38537386e-04
    1.45095983e-04 6.13025913e-05]
   [1.82537828e-04 1.36801289e-04 2.99562584e-04 ... 3.38142330e-04
    1.54873647e-04 6.14306919e-05]]

  [[3.92404851e-04 4.02827427e-04 1.76463276e-04 ... 4.20205877e-04
    5.28812961e-05 1.49793312e-04]
   [3.82949598e-04 4.22527548e-04 1.43619458e-04 ... 4.35577211e-04
    4.48260798e-05 1.48296982e-04]
   [3.96605988e-04 4.38760791e-04 1.25855746e-04 ... 4.70305356e-04
    4.96144166e-05 1.71426364e-04]
   ...
   [4.14760667e-04 4.68180922e-04 1.00599849e-04 ... 5.09189675e-04
    7.90524864e-05 1.81157608e-04]
   [4.64254932e-04 5.25650685e-04 1.29731765e-04 ... 5.07015502e-04
    1.05328996e-04 1.96792855e-04]
   [4.32733621e-04 5.38998225e-04 1.44492864e-04 ... 4.16888215e-04
    1.09614688e-04 1.55737289e-04]]

  [[6.29924005e-04 8.75293161e-04 1.65758265e-05 ... 4.00306220e-04
    2.50743411e-04 1.84439763e-04]
   [6.40036131e-04 8.94165249e-04 1.33943504e-05 ... 3.89037305e-04
    2.33151586e-04 1.94029039e-04]
   [6.61187747e-04 9.06503818e-04 1.22573747e-05 ... 3.73224699e-04
    2.01135626e-04 2.03824849e-04]
   ...
   [7.10445864e-04 9.74868773e-04 1.05954183e-04 ... 4.10941197e-04
    1.77134920e-04 2.77729094e-04]
   [7.22330064e-04 9.87224630e-04 1.46408041e-04 ... 4.89793834e-04
    2.27922908e-04 2.21888928e-04]
   [7.47129729e-04 8.95425968e-04 1.12158690e-04 ... 4.96434106e-04
    2.78070278e-04 2.01863164e-04]]

  ...

  [[1.01399631e-03 9.26075212e-04 3.55505850e-04 ... 4.93827160e-04
    1.90171981e-04 0.00000000e+00]
   [9.15840676e-04 9.74931405e-04 3.47134657e-04 ... 4.86112753e-04
    2.16427754e-04 0.00000000e+00]
   [9.45691892e-04 9.70017281e-04 3.42747051e-04 ... 5.06384182e-04
    2.05761229e-04 0.00000000e+00]
   ...
   [1.02776603e-03 1.08239020e-03 2.99930514e-04 ... 5.03655174e-04
    2.34171297e-04 0.00000000e+00]
   [1.07171398e-03 1.11420394e-03 3.16932623e-04 ... 5.20288188e-04
    2.87819072e-04 0.00000000e+00]
   [1.02422549e-03 1.16147648e-03 2.71256024e-04 ... 5.07929246e-04
    2.77252315e-04 0.00000000e+00]]

  [[4.92547872e-04 7.47851445e-04 1.44399091e-05 ... 4.62525495e-04
    2.30913283e-04 3.47128924e-04]
   [5.03479561e-04 7.81655544e-04 6.79127697e-05 ... 5.22462185e-04
    2.52625003e-04 3.83587001e-04]
   [5.62343688e-04 8.03932780e-04 9.96535309e-05 ... 5.33842423e-04
    2.72733072e-04 3.98322300e-04]
   ...
   [6.84298109e-04 8.80136329e-04 1.01809710e-04 ... 5.38182911e-04
    2.34777865e-04 3.69507412e-04]
   [6.70528505e-04 8.84762267e-04 4.34004978e-05 ... 5.38884546e-04
    2.23821233e-04 3.22921318e-04]
   [6.42941275e-04 8.57662060e-04 2.17795132e-05 ... 5.47378906e-04
    2.29498386e-04 3.15510581e-04]]

  [[1.03876845e-03 1.10945990e-03 5.40088280e-04 ... 5.48317737e-04
    3.76438722e-04 3.14492034e-04]
   [1.19622424e-03 9.71896457e-04 5.33628685e-04 ... 6.36384124e-04
    3.25021858e-04 3.06604605e-04]
   [1.24586164e-03 1.02527323e-03 5.89239411e-04 ... 6.42273924e-04
    2.19378431e-04 3.13747674e-04]
   ...
   [9.92287998e-04 1.27466465e-03 3.39535647e-04 ... 7.79865019e-04
    2.89358199e-04 5.87027986e-04]
   [1.14176655e-03 1.29631220e-03 2.14680273e-04 ... 9.47466353e-04
    3.03222303e-04 5.65312686e-04]
   [1.31252268e-03 1.41866598e-03 2.90135562e-04 ... 1.04595663e-03
    3.01967870e-04 5.18647081e-04]]]


 [[[7.10854481e-04 7.37991475e-04 2.93421763e-04 ... 5.88261173e-04
    1.28440588e-04 1.06910578e-04]
   [7.97449786e-04 8.57132662e-04 1.27664069e-04 ... 3.66285269e-04
    2.46460229e-04 1.26524363e-04]
   [8.10601516e-04 9.32887022e-04 1.34041737e-04 ... 4.16607392e-04
    2.35143758e-04 1.35113209e-04]
   ...
   [8.67621566e-04 1.00893085e-03 1.33566471e-04 ... 4.60331881e-04
    1.96246008e-04 6.91997557e-05]
   [8.77299171e-04 1.02991890e-03 1.62051467e-04 ... 4.85595054e-04
    2.05526332e-04 7.28754676e-05]
   [8.85637710e-04 1.07449258e-03 1.75801484e-04 ... 4.41888376e-04
    1.95041473e-04 1.66384620e-04]]

  [[9.22458770e-04 8.42583540e-04 4.45840415e-04 ... 4.81147639e-04
    1.51257045e-04 8.94362529e-05]
   [1.02397043e-03 9.43419931e-04 3.63497034e-04 ... 3.67190427e-04
    2.87128176e-04 8.90067968e-05]
   [1.04092190e-03 9.84147773e-04 4.04819730e-04 ... 3.51458788e-04
    2.60016532e-04 6.95883500e-05]
   ...
   [1.07681635e-03 1.05001824e-03 3.78916826e-04 ... 4.46891529e-04
    3.11385375e-04 1.20369201e-04]
   [1.05501851e-03 1.11514796e-03 3.76575364e-04 ... 4.05666273e-04
    2.83890055e-04 1.21962890e-04]
   [1.03215850e-03 1.06852397e-03 2.99386244e-04 ... 4.82639414e-04
    3.38290236e-04 1.89962855e-04]]

  [[8.88748211e-04 7.10604130e-04 3.66914581e-04 ... 5.21129114e-04
    1.42186618e-04 4.68612006e-06]
   [9.49139358e-04 8.41435511e-04 3.04429705e-04 ... 4.34760936e-04
    2.29241501e-04 0.00000000e+00]
   [9.79883596e-04 8.64936213e-04 3.02223780e-04 ... 4.58363007e-04
    2.01799805e-04 0.00000000e+00]
   ...
   [9.61827463e-04 9.77515709e-04 3.02839093e-04 ... 5.08928788e-04
    1.87764090e-04 0.00000000e+00]
   [9.95346578e-04 1.06978929e-03 3.15604790e-04 ... 4.98387555e-04
    1.99532253e-04 0.00000000e+00]
   [1.05671899e-03 1.11638138e-03 2.92386016e-04 ... 5.07155957e-04
    2.68453558e-04 0.00000000e+00]]

  ...

  [[5.59162581e-04 7.34351750e-04 3.12091666e-04 ... 6.16896141e-04
    1.52331180e-04 2.80045700e-04]
   [7.38079369e-04 8.88137321e-04 1.88482634e-04 ... 3.63229337e-04
    2.31589613e-04 3.22287349e-04]
   [8.33999366e-04 8.94319674e-04 2.87631410e-04 ... 4.45750571e-04
    2.53927603e-04 3.64456297e-04]
   ...
   [7.92206964e-04 9.41434701e-04 3.24853085e-04 ... 5.36050880e-04
    2.64650473e-04 4.08467255e-04]
   [8.00096022e-04 9.80016310e-04 2.83723319e-04 ... 6.25855057e-04
    2.55269202e-04 4.55779169e-04]
   [7.89348269e-04 1.06588798e-03 1.95560366e-04 ... 6.72260765e-04
    2.29951838e-04 4.53842542e-04]]

  [[9.83987935e-04 9.88321030e-04 4.54014225e-04 ... 6.33893418e-04
    1.43693251e-04 1.51177985e-04]
   [1.05111906e-03 1.15759298e-03 3.96282645e-04 ... 4.37213952e-04
    4.00508347e-04 2.22063216e-04]
   [1.02247356e-03 1.12440821e-03 4.80275397e-04 ... 4.70426341e-04
    4.39552648e-04 2.49174103e-04]
   ...
   [1.34357845e-03 1.05180871e-03 6.26424560e-04 ... 6.35378063e-04
    2.40585825e-04 3.37053061e-04]
   [1.10454625e-03 1.45997375e-03 5.35874628e-04 ... 7.84392003e-04
    3.37503589e-04 6.67511718e-04]
   [1.32705329e-03 1.40124734e-03 3.36048077e-04 ... 1.06697250e-03
    3.90748377e-04 9.65566840e-04]]

  [[1.08476111e-03 1.00947835e-03 6.09279727e-04 ... 4.60833369e-04
    1.67370192e-04 8.19877605e-05]
   [1.08651770e-03 1.07537699e-03 4.65534744e-04 ... 4.77025082e-04
    3.11424199e-04 1.08006076e-04]
   [9.67197528e-04 9.13763302e-04 3.08252551e-04 ... 6.70491892e-04
    3.39718477e-04 1.29134569e-04]
   ...
   [1.16076181e-03 1.04673347e-03 6.10144809e-04 ... 7.78044399e-04
    5.31244499e-04 3.45597509e-04]
   [1.50903140e-03 1.44378655e-03 7.74317828e-04 ... 7.83719937e-04
    7.09378859e-04 5.30979130e-04]
   [1.63699570e-03 1.56252272e-03 2.82912195e-04 ... 8.00483394e-04
    9.19496117e-04 9.87315550e-04]]]


 ...


 [[[1.46032180e-04 8.07877711e-07 1.38469652e-04 ... 5.67830342e-04
    1.16877309e-04 3.30004259e-05]
   [2.28246514e-04 7.34178029e-05 1.96369307e-04 ... 3.53338750e-04
    1.14852737e-04 4.03401573e-05]
   [2.74009217e-04 1.09459885e-04 2.07014702e-04 ... 3.85100167e-04
    1.40693184e-04 4.69509978e-05]
   ...
   [2.72217963e-04 1.88066027e-04 2.12719955e-04 ... 3.99237411e-04
    1.69784922e-04 6.86817802e-05]
   [2.70474469e-04 2.03346164e-04 2.31856029e-04 ... 4.03760670e-04
    1.62354394e-04 6.64896725e-05]
   [2.83681438e-04 2.20964648e-04 2.30307312e-04 ... 4.18091251e-04
    1.44668476e-04 6.26747642e-05]]

  [[4.01134486e-04 1.97723028e-04 1.80338466e-04 ... 7.83512951e-04
    1.60644413e-04 1.67284437e-04]
   [4.39298135e-04 6.08261034e-04 1.58398238e-04 ... 3.84922780e-04
    1.82812975e-04 1.45175276e-04]
   [5.12185274e-04 6.66685693e-04 3.12669727e-05 ... 4.00235847e-04
    1.73557128e-04 1.34752670e-04]
   ...
   [5.49835619e-04 7.46816397e-04 6.09966082e-06 ... 4.61395597e-04
    2.22302930e-04 1.68813145e-04]
   [5.89427713e-04 7.77186302e-04 1.15125231e-05 ... 4.87582904e-04
    2.18989837e-04 2.45270989e-04]
   [6.04299537e-04 8.05897813e-04 6.02933724e-05 ... 5.04005991e-04
    1.71660009e-04 2.65256851e-04]]

  [[6.45998283e-04 4.32915316e-04 2.10305661e-04 ... 7.77345907e-04
    1.69488907e-04 2.81430024e-04]
   [9.68809589e-04 9.16353136e-04 4.18833719e-04 ... 3.34604905e-04
    2.77157000e-04 1.19505101e-04]
   [1.04461017e-03 9.84568615e-04 3.80305835e-04 ... 3.95589392e-04
    2.80361361e-04 9.99324038e-05]
   ...
   [1.07368326e-03 1.04516454e-03 4.06285282e-04 ... 4.66662314e-04
    3.15020530e-04 9.02679458e-05]
   [1.08243106e-03 1.00685866e-03 3.71493603e-04 ... 4.77302703e-04
    2.38743654e-04 4.40952972e-05]
   [1.00512954e-03 1.06421777e-03 2.91105505e-04 ... 4.98967420e-04
    2.94698460e-04 9.99525946e-05]]

  ...

  [[6.43063278e-04 4.22982324e-04 2.28952063e-04 ... 8.50326614e-04
    1.86712321e-04 3.24516965e-04]
   [9.08785558e-04 1.07387768e-03 3.81416350e-04 ... 3.93914321e-04
    2.99593958e-04 3.73527699e-04]
   [9.80980694e-04 1.11937535e-03 4.02629230e-04 ... 4.11253888e-04
    3.87385749e-04 4.17925301e-04]
   ...
   [1.26358867e-03 9.97338211e-04 5.97186270e-04 ... 6.32164592e-04
    2.74944992e-04 4.54210414e-04]
   [1.17291848e-03 1.25439535e-03 5.50888944e-04 ... 6.93312322e-04
    2.69092998e-04 4.86040517e-04]
   [1.06682535e-03 1.29196281e-03 3.17708182e-04 ... 8.22430360e-04
    3.09882831e-04 6.32431882e-04]]

  [[7.70196435e-04 4.60048585e-04 1.78739021e-04 ... 5.14496875e-04
    3.42661602e-04 3.96112358e-04]
   [1.08232303e-03 9.34863172e-04 5.76998573e-04 ... 4.58521274e-04
    2.99240288e-04 6.90010565e-05]
   [1.01821427e-03 9.26145003e-04 4.08783351e-04 ... 5.69504162e-04
    5.10575832e-04 1.61671269e-04]
   ...
   [1.03875133e-03 9.75624658e-04 3.42536950e-04 ... 6.93304581e-04
    6.05744659e-04 9.89230757e-05]
   [1.53240305e-03 9.72798443e-04 5.35362109e-04 ... 8.11267702e-04
    4.84509190e-04 9.39876481e-05]
   [1.45661819e-03 1.55647611e-03 8.58778018e-04 ... 6.81465724e-04
    8.85112386e-04 4.52759501e-04]]

  [[4.52078064e-04 1.03278900e-04 5.46023693e-05 ... 3.47232562e-04
    1.62081036e-04 1.17608150e-04]
   [5.59847278e-04 2.55742547e-04 8.26897449e-05 ... 4.20462893e-04
    2.33823986e-04 0.00000000e+00]
   [6.14149903e-04 3.36373021e-04 1.94526627e-04 ... 5.47973905e-04
    2.60451052e-04 0.00000000e+00]
   ...
   [7.19495467e-04 4.81806404e-04 2.12993778e-04 ... 5.41406916e-04
    2.54974992e-04 0.00000000e+00]
   [9.88315092e-04 4.22145182e-04 2.19815993e-04 ... 5.94702375e-04
    3.39266815e-04 0.00000000e+00]
   [7.75444380e-04 5.27048309e-04 1.62256096e-04 ... 4.42705757e-04
    2.84789770e-04 0.00000000e+00]]]


 [[[6.63125073e-04 8.16294865e-04 1.18979922e-04 ... 3.61273356e-04
    2.08943908e-04 3.75401461e-04]
   [6.87746680e-04 8.28329765e-04 1.58126044e-04 ... 3.99342884e-04
    2.33858533e-04 4.17474541e-04]
   [7.44597288e-04 8.43887799e-04 1.74501765e-04 ... 4.30012733e-04
    2.52876373e-04 4.21988952e-04]
   ...
   [7.84014352e-04 8.97379185e-04 2.63651891e-04 ... 4.58932831e-04
    2.57955486e-04 4.34779620e-04]
   [7.39281997e-04 8.97741294e-04 2.69548094e-04 ... 4.94365057e-04
    2.53883394e-04 4.57990915e-04]
   [7.81960669e-04 8.92954646e-04 2.84024514e-04 ... 5.11906459e-04
    2.63854279e-04 4.63571603e-04]]

  [[8.44484137e-04 9.96576622e-04 2.92343029e-04 ... 3.71764996e-04
    2.92000535e-04 2.98904139e-04]
   [8.64055473e-04 9.86386673e-04 3.54333286e-04 ... 4.07565240e-04
    2.90654047e-04 3.27297952e-04]
   [9.35498043e-04 9.84020764e-04 4.36717150e-04 ... 4.52545588e-04
    2.81152374e-04 3.47133406e-04]
   ...
   [9.69682762e-04 9.15489567e-04 5.14731801e-04 ... 5.47468022e-04
    2.01346527e-04 3.24325403e-04]
   [9.39500111e-04 9.29053233e-04 4.56701935e-04 ... 5.90210315e-04
    1.74034794e-04 3.38222890e-04]
   [8.54231825e-04 1.01723487e-03 4.05240571e-04 ... 5.83253102e-04
    2.34574618e-04 3.59822821e-04]]

  [[9.76243871e-04 1.11528917e-03 4.01948462e-04 ... 4.08625987e-04
    3.43210064e-04 2.22256480e-04]
   [9.71455243e-04 1.09257875e-03 4.08815540e-04 ... 3.99400655e-04
    3.84025712e-04 1.92512351e-04]
   [9.78455413e-04 1.10306230e-03 4.84157674e-04 ... 4.57645860e-04
    3.77066550e-04 2.39838511e-04]
   ...
   [1.16371666e-03 9.73625050e-04 5.44813811e-04 ... 6.32302777e-04
    3.65382468e-04 2.05583710e-04]
   [1.25537557e-03 9.94917704e-04 5.93954232e-04 ... 6.51662413e-04
    1.97476242e-04 2.15958862e-04]
   [1.26810186e-03 1.05599326e-03 5.98107057e-04 ... 6.39658945e-04
    1.73739390e-04 2.92974146e-04]]

  ...

  [[1.04716222e-03 1.03224849e-03 4.13265545e-04 ... 5.45349612e-04
    3.09587456e-04 1.54657446e-05]
   [9.73114860e-04 9.32716299e-04 3.42349231e-04 ... 5.79872052e-04
    2.64403236e-04 0.00000000e+00]
   [9.42885294e-04 8.64110887e-04 2.92415498e-04 ... 6.70159177e-04
    1.99210423e-04 0.00000000e+00]
   ...
   [9.36210097e-04 1.00371649e-03 3.70315305e-04 ... 7.08802196e-04
    4.23880963e-04 2.64828559e-05]
   [1.04664383e-03 1.04641472e-03 5.50359720e-04 ... 7.43164332e-04
    3.68925917e-04 1.50104344e-04]
   [1.13494636e-03 9.35503514e-04 4.99908230e-04 ... 5.94375655e-04
    4.64240467e-04 2.71167228e-04]]

  [[1.16281048e-03 7.97411718e-04 5.39824716e-04 ... 4.31497872e-04
    2.82848574e-04 3.36937548e-04]
   [1.12968171e-03 5.93967619e-04 4.36897244e-04 ... 5.15047926e-04
    3.58606951e-04 3.92925984e-04]
   [1.08200079e-03 6.03850465e-04 2.96960032e-04 ... 6.36938319e-04
    5.80412452e-04 4.15799179e-04]
   ...
   [1.03967765e-03 9.59517783e-04 3.67904984e-04 ... 7.17084855e-04
    6.26586145e-04 1.90826409e-04]
   [9.77858319e-04 9.73211427e-04 2.03299656e-04 ... 6.36947865e-04
    5.85725065e-04 8.13906081e-05]
   [1.22413563e-03 9.21816274e-04 1.71831576e-04 ... 5.74408332e-04
    6.44093612e-04 4.47476996e-05]]

  [[1.01846212e-03 5.09886711e-04 4.40791628e-04 ... 3.07506212e-04
    4.02505422e-04 1.50235297e-04]
   [9.59796365e-04 3.82200087e-04 4.21057572e-04 ... 3.07357492e-04
    4.88732243e-04 2.65923911e-04]
   [9.63180617e-04 5.44764451e-04 2.12207291e-04 ... 4.13225032e-04
    5.36678301e-04 3.26671201e-04]
   ...
   [9.14110278e-04 7.76315457e-04 2.15325359e-04 ... 5.06968703e-04
    5.55260805e-04 1.78733157e-04]
   [1.01879437e-03 7.87431432e-04 1.06183710e-04 ... 4.61999909e-04
    6.18940685e-04 1.10073961e-04]
   [1.23866485e-03 7.13095302e-04 2.55639869e-04 ... 5.35024737e-04
    6.37914229e-04 3.81376485e-05]]]


 [[[4.61888034e-04 6.09130948e-04 0.00000000e+00 ... 4.93708998e-04
    1.68366809e-04 1.62380602e-04]
   [4.92671737e-04 6.76792930e-04 0.00000000e+00 ... 5.35223517e-04
    1.66602593e-04 1.49476560e-04]
   [5.31188794e-04 7.61642703e-04 0.00000000e+00 ... 5.35372470e-04
    1.89130762e-04 1.55119531e-04]
   ...
   [6.22987282e-04 8.98912840e-04 0.00000000e+00 ... 5.50456520e-04
    1.41620010e-04 1.29751556e-04]
   [6.30784256e-04 8.85111920e-04 0.00000000e+00 ... 5.43296337e-04
    1.62976692e-04 1.04524253e-04]
   [6.24689157e-04 9.15140612e-04 0.00000000e+00 ... 5.38314052e-04
    1.87846032e-04 8.55938124e-05]]

  [[4.69201856e-04 7.28031329e-04 2.05415417e-05 ... 4.27531253e-04
    2.13629421e-04 3.17389669e-04]
   [5.03153948e-04 7.58954964e-04 5.64207985e-05 ... 4.76691319e-04
    2.35648506e-04 3.33204429e-04]
   [5.89594827e-04 8.02538125e-04 1.05187224e-04 ... 4.92074818e-04
    2.46964366e-04 3.32212658e-04]
   ...
   [6.42300118e-04 8.97333608e-04 4.08591695e-05 ... 5.36173815e-04
    2.01701201e-04 2.82717432e-04]
   [6.20675622e-04 8.76816630e-04 1.02843187e-05 ... 5.24186005e-04
    1.96357287e-04 2.63707538e-04]
   [6.62022561e-04 9.09602793e-04 0.00000000e+00 ... 5.29554265e-04
    2.22179951e-04 2.29623605e-04]]

  [[6.95348077e-04 8.56250233e-04 1.80514733e-04 ... 4.31301509e-04
    2.52090627e-04 4.10353736e-04]
   [6.71692251e-04 8.62184155e-04 2.11321865e-04 ... 4.72889660e-04
    2.58270884e-04 4.39681229e-04]
   [7.56172871e-04 8.66277900e-04 2.45115807e-04 ... 4.93777858e-04
    2.74741120e-04 4.41300159e-04]
   ...
   [7.72296567e-04 9.65878484e-04 1.65165053e-04 ... 5.57260821e-04
    2.34145569e-04 4.35704482e-04]
   [7.22607656e-04 9.22873558e-04 1.17526535e-04 ... 5.59279462e-04
    2.91681325e-04 4.78223315e-04]
   [7.39462324e-04 9.54916410e-04 6.71359157e-05 ... 5.93038974e-04
    3.24369408e-04 4.66646510e-04]]

  ...

  [[1.02537055e-03 9.80263343e-04 5.34047955e-04 ... 5.84952184e-04
    2.61863810e-04 2.27753306e-04]
   [1.02480117e-03 9.68898588e-04 4.98235691e-04 ... 6.19755650e-04
    1.47410174e-04 2.37508968e-04]
   [9.51931288e-04 1.08102430e-03 4.34072805e-04 ... 6.05455483e-04
    2.36376829e-04 3.60293896e-04]
   ...
   [8.78211518e-04 1.13720179e-03 2.07752179e-04 ... 7.94515479e-04
    1.95530200e-04 3.52448813e-04]
   [9.91355279e-04 1.21842849e-03 1.27287640e-04 ... 8.85642017e-04
    1.93387925e-04 4.44241101e-04]
   [1.00810151e-03 1.44348270e-03 8.58277417e-05 ... 9.08978516e-04
    3.85017571e-04 5.89102332e-04]]

  [[1.11832179e-03 1.03061448e-03 5.60064684e-04 ... 6.56276417e-04
    3.78405442e-04 2.32131119e-04]
   [1.26035057e-03 1.00981991e-03 5.95771940e-04 ... 6.64868508e-04
    2.16468965e-04 2.26243021e-04]
   [1.24655105e-03 1.11724716e-03 5.87466173e-04 ... 6.42505824e-04
    1.83350872e-04 3.45389155e-04]
   ...
   [1.08626939e-03 1.25130103e-03 2.87671603e-04 ... 8.62188521e-04
    3.10534437e-04 6.18889462e-04]
   [1.36575825e-03 1.27017160e-03 2.86309572e-04 ... 1.05184119e-03
    2.92746670e-04 5.58924396e-04]
   [1.33894070e-03 1.68739317e-03 3.13424243e-04 ... 1.04551692e-03
    4.44009434e-04 6.91532739e-04]]

  [[1.10974570e-03 1.01111422e-03 5.01895964e-04 ... 4.97402332e-04
    4.15584334e-04 1.78697956e-04]
   [1.34606811e-03 9.95704671e-04 6.41738065e-04 ... 6.46184897e-04
    3.00126034e-04 3.00160900e-04]
   [1.33882579e-03 1.06573268e-03 6.72655238e-04 ... 6.54346484e-04
    3.10959498e-04 4.16736177e-04]
   ...
   [1.22677570e-03 1.37567311e-03 4.13820089e-04 ... 8.72149481e-04
    5.37466025e-04 9.69339977e-04]
   [1.74198754e-03 1.41469191e-03 4.68753773e-04 ... 1.10176951e-03
    5.04795811e-04 7.42289994e-04]
   [1.81470963e-03 1.70377118e-03 6.32593466e-04 ... 9.86305764e-04
    4.76391579e-04 7.20137963e-04]]]], shape=(256, 7, 7, 512), dtype=float32)

After ROI Pooling, the cropped regions are fed into fully connected layers for object classification and box-coordinate regression. Note the shapes above: crops is (256, 14, 14, 512) and pool5 is (256, 7, 7, 512), i.e. 256 ROIs, each reduced to the fixed 7×7×512 input that the fully connected head expects.

# Pass through two fully connected layers
fc7 = vgg16.head_to_tail(pool5)
# Predict the final class probabilities and bounding box for each roi
classes = ['__background__', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'aeroplane',
           'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train', 'bottle', 'chair',
           'diningtable', 'pottedplant', 'sofa', 'tvmonitor', 'person']
# Multi-task network: classification fully connected layer
cls_score = layers.Dense(units=len(classes), kernel_regularizer='l2')(fc7)
cls_prob = layers.Softmax(name='cls_prob')(cls_score)
print('cls_prob', cls_prob)
# Multi-task network: box-coordinate fully connected layer
bbox_pred = layers.Dense(units=len(classes) * 4, kernel_regularizer='l2')(fc7)
print('bbox_pred', bbox_pred)

Output:

cls_prob tf.Tensor(
[[0.04764251 0.04766487 0.04761864 ... 0.04762166 0.04763    0.04765202]
 [0.04763313 0.04763036 0.04762237 ... 0.04762399 0.04763886 0.04765267]
 [0.0476309  0.04762912 0.04761694 ... 0.04762627 0.04764616 0.04766076]
 ...
 [0.04762792 0.04768769 0.04761653 ... 0.04763751 0.0476356  0.04767268]
 [0.04763427 0.04762722 0.047622   ... 0.04761866 0.04763188 0.0476636 ]
 [0.04763592 0.04767709 0.04761501 ... 0.0476355  0.04765055 0.04766272]], shape=(256, 21), dtype=float32)
bbox_pred tf.Tensor(
[[ 1.9778885e-04 -3.3251714e-04 -4.9960869e-04 ... -6.1972119e-04
  -1.0668292e-03 -1.7538434e-05]
 [ 1.3820577e-04 -5.8698573e-04 -6.4890028e-04 ... -5.5852829e-04
  -7.8546786e-04  6.5681414e-04]
 [-3.1575924e-05 -5.3623994e-04 -6.7997043e-04 ... -2.4759545e-04
  -1.1621144e-03  1.8989285e-04]
 ...
 [ 4.4140400e-04 -4.2886240e-04 -7.0175709e-04 ... -3.4246917e-04
  -9.4860920e-04 -8.4239960e-05]
 [ 2.6616781e-05 -4.6439210e-04 -5.2275538e-04 ... -2.9216046e-04
  -5.0384674e-04  4.9830112e-04]
 [ 3.7231349e-04 -5.3368503e-04 -3.2786035e-04 ... -2.0706654e-04
  -1.4091514e-03  5.8860541e-04]], shape=(256, 84), dtype=float32)

Next we assemble the network model. Note the shapes above: cls_prob is (256, 21), one probability per class for each of the 256 ROIs (since the network is untrained, every probability sits near the uniform 1/21 ≈ 0.0476), and bbox_pred is (256, 84), i.e. 4 box deltas for each of the 21 classes.

predictions['cls_score'] = cls_score
predictions['cls_prob'] = cls_prob
predictions['bbox_pred'] = bbox_pred
# Build the model; inputs are the image plus the annotated boxes and classes,
# outputs are the RPN outputs and the ROI Pooling multi-task head outputs
model = models.Model(inputs=[im_inputs, gt_boxes],
                     outputs=[anchor_targets, proposal_targets, predictions, cls_prob, bbox_pred])

Now let's print the model's network structure (model.summary() prints the table itself and returns None, hence the trailing None in the output below).

print(model.summary())

Output:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(1, None, None, 3)] 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (1, None, None, 64)  1792        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (1, None, None, 64)  36928       conv2d[0][0]                     
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (1, None, None, 64)  0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (1, None, None, 128) 73856       max_pooling2d[0][0]              
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (1, None, None, 128) 147584      conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (1, None, None, 128) 0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (1, None, None, 256) 295168      max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (1, None, None, 256) 590080      conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (1, None, None, 256) 590080      conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (1, None, None, 256) 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (1, None, None, 512) 1180160     max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (1, None, None, 512) 2359808     conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (1, None, None, 512) 2359808     conv2d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (1, None, None, 512) 0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (1, None, None, 512) 2359808     max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (1, None, None, 512) 2359808     conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (1, None, None, 512) 2359808     conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (1, None, None, 512) 2359808     conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (1, None, None, 18)  9234        conv2d_13[0][0]                  
__________________________________________________________________________________________________
tf.compat.v1.shape_2 (TFOpLambd (4,)                 0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
tf.compat.v1.shape_3 (TFOpLambd (4,)                 0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
tf.compat.v1.shape (TFOpLambda) (4,)                 0           input_1[0][0]                    
__________________________________________________________________________________________________
tf.compat.v1.shape_1 (TFOpLambd (4,)                 0           input_1[0][0]                    
__________________________________________________________________________________________________
tf.reshape (TFOpLambda)         (None, 2)            0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.getitem_2 (Sli ()                   0           tf.compat.v1.shape_2[0][0]       
__________________________________________________________________________________________________
tf.__operators__.getitem_3 (Sli ()                   0           tf.compat.v1.shape_3[0][0]       
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici ()                   0           tf.compat.v1.shape[0][0]         
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli ()                   0           tf.compat.v1.shape_1[0][0]       
__________________________________________________________________________________________________
softmax (Softmax)               (None, 2)            0           tf.reshape[0][0]                 
__________________________________________________________________________________________________
tf.compat.v1.shape_4 (TFOpLambd (4,)                 0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
generate_anchors (GenerateAncho ((None, 4), ())      0           tf.__operators__.getitem_2[0][0] 
                                                                 tf.__operators__.getitem_3[0][0] 
__________________________________________________________________________________________________
input_2 (InputLayer)            [(1, None, 5)]       0                                            
__________________________________________________________________________________________________
tf.cast (TFOpLambda)            (2,)                 0           tf.__operators__.getitem[0][0]   
                                                                 tf.__operators__.getitem_1[0][0] 
__________________________________________________________________________________________________
tf.reshape_2 (TFOpLambda)       (1, None, None, 18)  0           softmax[0][0]                    
                                                                 tf.compat.v1.shape_4[0][0]       
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (1, None, None, 36)  18468       conv2d_13[0][0]                  
__________________________________________________________________________________________________
proposal_layer (ProposalLayer)  ((None, 5), (None, 1 0           tf.reshape_2[0][0]               
                                                                 generate_anchors[0][0]           
                                                                 tf.cast[0][0]                    
                                                                 conv2d_15[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.getitem_5 (Sli (None, 5)            0           input_2[0][0]                    
__________________________________________________________________________________________________
proposal_target_layer (Proposal ((None, 5), (None,), 0           proposal_layer[0][0]             
                                                                 tf.__operators__.getitem_5[0][0] 
                                                                 proposal_layer[0][1]             
__________________________________________________________________________________________________
crop_pool_layer (CropPoolLayer) (None, 7, 7, 512)    0           conv2d_12[0][0]                  
                                                                 tf.cast[0][0]                    
                                                                 proposal_target_layer[0][0]      
__________________________________________________________________________________________________
flatten (Flatten)               (None, 25088)        0           crop_pool_layer[0][0]            
__________________________________________________________________________________________________
dense (Dense)                   (None, 4096)         102764544   flatten[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 4096)         0           dense[0][0]                      
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 4096)         16781312    dropout[0][0]                    
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 4096)         0           dense_1[0][0]                    
__________________________________________________________________________________________________
tf.__operators__.getitem_4 (Sli (None, 5)            0           input_2[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 21)           86037       dropout_1[0][0]                  
__________________________________________________________________________________________________
tf.reshape_1 (TFOpLambda)       (None, 2)            0           tf.reshape[0][0]                 
__________________________________________________________________________________________________
anchor_target_layer (AnchorTarg ((1, None, None, 9), 0           conv2d_14[0][0]                  
                                                                 generate_anchors[0][0]           
                                                                 tf.__operators__.getitem_4[0][0] 
                                                                 tf.cast[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 84)           344148      dropout_1[0][0]                  
__________________________________________________________________________________________________
cls_prob (Softmax)              (None, 21)           0           dense_2[0][0]                    
__________________________________________________________________________________________________
tf.math.argmax (TFOpLambda)     (None,)              0           tf.reshape_1[0][0]               
==================================================================================================
Total params: 137,078,239
Trainable params: 137,078,239
Non-trainable params: 0
__________________________________________________________________________________________________
None
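
As a sanity check on the summary, the larger parameter counts can be reproduced by hand from the layer shapes shown above; a quick sketch, plain arithmetic only:

# Dense layers: params = inputs * units + units (weights + biases)
print(25088 * 4096 + 4096)  # dense   (flattened 7*7*512 -> 4096): 102764544
print(4096 * 4096 + 4096)   # dense_1 (4096 -> 4096):              16781312
print(4096 * 21 + 21)       # dense_2 (cls_score, 21 classes):     86037
print(4096 * 84 + 84)       # dense_3 (bbox_pred, 21 * 4 coords):  344148
# Conv2D layers: params = kh * kw * in_channels * filters + filters
print(3 * 3 * 3 * 64 + 64)  # conv2d (3x3 kernel, RGB -> 64):      1792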

For inference (non-training), the prediction network is:

else:
    rois, roi_scores = region_proposal_network(conv_net=feature_map,
                                               anchors=anchors,
                                               gt_boxes=gt_boxes,
                                               im_info=im_info,
                                               is_training=is_training)
    # Post-RPN conv, pooling and fully connected layers: second bbox regression
    # and class prediction
    # Crop layer: crop the features output by the conv backbone according to the
    # rois from the RPN layer, then resize them to a uniform size;
    # returns [bbox_nums, pre_pool_size, pre_pool_size, depth]
    # Batch indices into the feature maps; we only feed one image, so all zeros
    batch_ids = tf.squeeze(tf.slice(rois, [0, 0], [-1, 1]), [1])
    # Height and width of the feature map
    height = im_info[0]
    width = im_info[1]
    # Normalized coordinates of the post-NMS fg/bg boxes on the feature map
    x1 = tf.expand_dims(rois[:, 1] / width, 1)
    y1 = tf.expand_dims(rois[:, 2] / height, 1)
    x2 = tf.expand_dims(rois[:, 3] / width, 1)
    y2 = tf.expand_dims(rois[:, 4] / height, 1)
    # Assemble the coordinate boxes
    bboxes = tf.concat([y1, x1, y2, x2], axis=1)
    pool_size_after_rpn = 7
    pre_pool_size = pool_size_after_rpn * 2
    # [bbox_nums, pre_pool_size, pre_pool_size, depth]
    # Crop the post-NMS fg/bg boxes out of the feature map
    crops = tf.image.crop_and_resize(image=feature_map,
                                     boxes=bboxes,
                                     box_indices=tf.cast(batch_ids, dtype=tf.int32),
                                     crop_size=[pre_pool_size, pre_pool_size])
    # Downsample the cropped regions
    pool5 = layers.MaxPooling2D(pool_size=(2, 2), padding='SAME')(crops)
    # Pass through two fully connected layers
    fc7 = vgg16.head_to_tail(pool5)
    classes = ['__background__', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'aeroplane',
               'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train', 'bottle', 'chair',
               'diningtable', 'pottedplant', 'sofa', 'tvmonitor', 'person']
    # Multi-task network: classification fully connected layer
    cls_score = layers.Dense(units=len(classes), kernel_regularizer='l2')(fc7)
    cls_prob = layers.Softmax(name='cls_prob')(cls_score)
    # Multi-task network: box-coordinate fully connected layer
    bbox_pred = layers.Dense(units=len(classes) * 4, kernel_regularizer='l2')(fc7)
    # Build the model; the input is just the image,
    # the outputs are the ROI Pooling multi-task head outputs
    model = models.Model(inputs=[im_inputs],
                         outputs=[rois, cls_prob, bbox_pred])
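
Note that tf.image.crop_and_resize expects boxes in normalized [y1, x1, y2, x2] order, which is exactly why the code above divides by the feature map's width and height and concatenates in that order. A minimal self-contained sketch with a dummy feature map and one hypothetical ROI:

import tensorflow as tf

# Dummy single-image feature map (8 channels instead of 512 for brevity)
feat = tf.random.uniform((1, 38, 38, 8))
# One ROI covering the top-left quarter, normalized [y1, x1, y2, x2]
boxes = tf.constant([[0.0, 0.0, 0.5, 0.5]])
box_indices = tf.constant([0])  # every box comes from image 0 of the batch
crops = tf.image.crop_and_resize(feat, boxes, box_indices, crop_size=[14, 14])
pool5 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(crops)
print(crops.shape, pool5.shape)  # (1, 14, 14, 8) (1, 7, 7, 8)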

Now we can start the training process; the training images are in VOC data format.

# Optimizer
optimizer = optimizers.Adam(1e-5)
# Load the training data
train_data_generator = DataGenerator(voc_data_path='./data/voc2012_46_samples',
                                     classes=classes,
                                     batch_size=1,
                                     feat_stride=16,
                                     train_fg_thresh=0.5,
                                     train_bg_thresh_hi=0.5,
                                     train_bg_thresh_lo=0.1)
# Writer for the log files
summary_writer = tf.summary.create_file_writer('./logs')

Let's take a look at the DataGenerator class.

import os
import math
import numpy as np
import cv2
import tensorflow as tf
from data.xml_ops import xml2dict
from anchors_ops import generate_anchors_pre_tf
from bbox_ops import bbox_overlaps_tf, clip_boxes_tf


class DataGenerator:
    def __init__(self,
                 voc_data_path,
                 classes,
                 im_size=600,
                 data_max_size=15000,
                 data_max_size_per_class=350,
                 max_box_fraction=0.5,
                 batch_size=1,
                 is_training=True,
                 feat_stride=16,
                 train_fg_thresh=0.5,
                 train_bg_thresh_hi=0.5,
                 train_bg_thresh_lo=0.1,
                 pixel_mean=np.array([[[102.9801, 115.9465, 122.7717]]]),
                 is_voc_2012=False,
                 ):
        self.annotation_files_root_path = os.path.join(voc_data_path, "Annotations")
        self.img_files_root_path = os.path.join(voc_data_path, "JPEGImages")
        self.imgset_root_path = os.path.join(voc_data_path, "ImageSets", "Segmentation")
        self.max_box_fraction = max_box_fraction
        self.batch_size = batch_size
        self.is_training = is_training
        self.feat_stride = feat_stride
        self.train_fg_thresh = train_fg_thresh
        self.train_bg_thresh_hi = train_bg_thresh_hi
        self.train_bg_thresh_lo = train_bg_thresh_lo
        self.pixel_mean = pixel_mean
        self.data_max_size = data_max_size
        self.data_max_size_per_class = data_max_size_per_class
        self.im_size = im_size
        # classes = ["__background__", ...]
        self.classes = classes
        assert self.classes[0] == "__background__", "classes index 0 need to be __background__"
        self.cls_to_inds = dict(list(zip(self.classes, list(range(len(self.classes))))))
        # Initialization
        self.total_batch_size = 0
        self.img_files = []
        self.annotation_files = []
        self.current_batch_index = 0
        self.file_indices = []
        # color map

        # Reassign
        self.__load_files()
        # self._on_epoch_end()
        # Filter out samples that contain only small objects
        self.__filter_small_objs()
        # Filter out samples whose objects are too large
        self.__filter_big_objs()
        # Balance the number of samples per class
        self.__balance_class_data()

    def __load_files(self):
        if self.is_training:
            file = os.path.join(self.imgset_root_path, "trainval.txt")
        else:
            file = os.path.join(self.imgset_root_path, 'test.txt')

        img_files = []
        annotation_files = []
        with open(file, encoding='utf-8', mode='r') as f:
            data = f.readlines()
            for file_name in data:
                file_name = file_name.strip()
                img_file_jpeg = os.path.join(self.img_files_root_path, "{}.jpeg".format(file_name))
                img_file_jpg = os.path.join(self.img_files_root_path, "{}.jpg".format(file_name))
                annotation_file = os.path.join(self.annotation_files_root_path, "{}.xml".format(file_name))
                if os.path.isfile(annotation_file):
                    if os.path.isfile(img_file_jpeg):
                        img_files.append(img_file_jpeg)
                        annotation_files.append(annotation_file)
                    elif os.path.isfile(img_file_jpg):
                        img_files.append(img_file_jpg)
                        annotation_files.append(annotation_file)

            self.img_files = img_files[:self.data_max_size]
            self.annotation_files = annotation_files[:self.data_max_size]

        self.total_batch_size = int(math.floor(len(self.annotation_files) / self.batch_size))
        self.file_indices = np.arange(len(self.annotation_files))
        # np.random.shuffle(self.file_indices)

    def __filter_big_objs(self):
        """ 過濾目標太大的樣本數據 """
        filter_annotation_files = []
        filter_img_files = []

        for i in range(len(self.annotation_files)):
            # for i in self.file_indices:
            annotation = xml2dict(self.annotation_files[i])
            img_width = int(annotation['annotation']['size']['width'])
            img_height = int(annotation['annotation']['size']['height'])
            objs = annotation['annotation']['object']

            area = img_height * img_width
            keep = True
            if type(objs) == list:
                for box in objs:
                    xmin = int(float(box['bndbox']['xmin']))
                    ymin = int(float(box['bndbox']['ymin']))
                    xmax = int(float(box['bndbox']['xmax']))
                    ymax = int(float(box['bndbox']['ymax']))
                    if (ymax - ymin) * (xmax - xmin) / area > self.max_box_fraction:
                        keep = False
            else:
                xmin = int(float(objs['bndbox']['xmin']))
                ymin = int(float(objs['bndbox']['ymin']))
                xmax = int(float(objs['bndbox']['xmax']))
                ymax = int(float(objs['bndbox']['ymax']))
                if (ymax - ymin) * (xmax - xmin) / area > self.max_box_fraction:
                    keep = False

            if keep:
                filter_annotation_files.append(self.annotation_files[i])
                filter_img_files.append(self.img_files[i])
            else:
                print("filter big obj file: {}, {}".format(self.annotation_files[i], self.img_files[i]))

        remove_file_nums = len(self.annotation_files) - len(filter_annotation_files)
        self.annotation_files = filter_annotation_files
        self.img_files = filter_img_files

        self.total_batch_size = int(math.floor(len(self.annotation_files) / self.batch_size))
        # self.total_batch_size = int(math.floor(len(not_filter_file_indices) / self.batch_size))
        self.file_indices = np.arange(len(self.annotation_files))
        # self.file_indices = not_filter_file_indices
        print("after filter big obj, total file nums: {}, remove {} files".format(len(self.annotation_files),
                                                                                  remove_file_nums))

    def __filter_small_objs(self):
        """ 過濾只包含小目標的樣本數據 """
        filter_annotation_files = []
        filter_img_files = []

        for i in range(len(self.annotation_files)):
            # for i in self.file_indices:
            annotation = xml2dict(self.annotation_files[i])
            print(self.annotation_files[i])
            print(annotation)

            # Pre-generated anchors
            img_width = int(annotation['annotation']['size']['width'])
            img_height = int(annotation['annotation']['size']['height'])
            anchors, _ = generate_anchors_pre_tf(height=int(img_height / self.feat_stride),
                                                 width=int(img_width / self.feat_stride),
                                                 feat_stride=self.feat_stride)
            inds_inside = tf.reshape(tf.where(
                (anchors[:, 0] >= -0) &
                (anchors[:, 1] >= -0) &
                (anchors[:, 2] < (img_width + 0)) &  # width
                (anchors[:, 3] < (img_height + 0))  # height
            ), shape=(-1,))

            clip_anchors = clip_boxes_tf(anchors, [img_height, img_width])

            # gt_boxes
            objs = annotation['annotation']['object']
            boxes = []

            if type(objs) == list:
                for box in objs:
                    # cls_inds = self.cls_to_inds[box['name']]
                    xmin = int(float(box['bndbox']['xmin']))
                    ymin = int(float(box['bndbox']['ymin']))
                    xmax = int(float(box['bndbox']['xmax']))
                    ymax = int(float(box['bndbox']['ymax']))
                    boxes.append([xmin, ymin, xmax, ymax])
            else:
                # cls_inds = self.cls_to_inds[objs['name']]
                xmin = int(float(objs['bndbox']['xmin']))
                ymin = int(float(objs['bndbox']['ymin']))
                xmax = int(float(objs['bndbox']['xmax']))
                ymax = int(float(objs['bndbox']['ymax']))
                boxes.append([xmin, ymin, xmax, ymax])

            boxes = tf.cast(boxes, dtype=tf.float32)
            # Compute the IoU
            overlaps = bbox_overlaps_tf(clip_anchors, boxes).numpy()
            max_overlaps = overlaps.max(axis=1)

            fg_inds = np.where(max_overlaps >= self.train_fg_thresh)[0]
            # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
            bg_inds = np.where((max_overlaps < self.train_bg_thresh_hi) &
                               (max_overlaps >= self.train_bg_thresh_lo))[0]

            # image is only valid if such boxes exist
            if (len(fg_inds) > 0 or len(bg_inds) > 0) and tf.shape(inds_inside)[0] > 0:
                # not_filter_file_indices.append(i)
                filter_annotation_files.append(self.annotation_files[i])
                filter_img_files.append(self.img_files[i])
            else:
                print("filter small obj file: {}, {}".format(self.annotation_files[i], self.img_files[i]))

        remove_file_nums = len(self.annotation_files) - len(filter_annotation_files)
        self.annotation_files = filter_annotation_files
        self.img_files = filter_img_files

        self.total_batch_size = int(math.floor(len(self.annotation_files) / self.batch_size))
        # self.total_batch_size = int(math.floor(len(not_filter_file_indices) / self.batch_size))
        self.file_indices = np.arange(len(self.annotation_files))
        # self.file_indices = not_filter_file_indices
        print("after filter small obj, total file nums: {}, remove {} files".format(len(self.annotation_files),
                                                                                    remove_file_nums))

    def __balance_class_data(self):
        """ 平衡每個類別樣本數 """
        # balance_annotation_files = []
        # balance_img_files = []
        balance_file_indices = []
        per_class_nums = dict(zip(self.classes, [0] * len(self.classes)))

        for i in self.file_indices:
            annotation = xml2dict(self.annotation_files[i])
            objs = annotation['annotation']['object']

            all_classes = []
            if type(objs) == list:
                for obj in objs:
                    all_classes.append(obj['name'])
            else:
                all_classes.append(objs['name'])

            keep = False
            # if 'person' not in all_classes:
            for cls in set(all_classes):
                if per_class_nums[cls] <= self.data_max_size_per_class:
                    keep = True
                    per_class_nums[cls] += 1
            if keep:
                balance_file_indices.append(i)
                # balance_annotation_files.append(self.annotation_files[i])
                # balance_img_files.append(self.img_files[i])

        remove_file_nums = len(self.annotation_files) - len(balance_file_indices)
        # remove_file_nums = len(self.file_indices) - len(balance_file_indices)
        # self.annotation_files = balance_annotation_files
        # self.img_files = balance_img_files

        # self.total_batch_size = int(math.floor(len(self.annotation_files) / self.batch_size))
        self.total_batch_size = int(math.floor(len(balance_file_indices) / self.batch_size))
        # self.file_indices = np.arange(len(self.annotation_files))
        self.file_indices = balance_file_indices
        print("after balance total file nums: {}, remove {} files".format(len(balance_file_indices), remove_file_nums))
        print("every class nums: {}".format(per_class_nums))

    def next_batch(self):
        if self.current_batch_index >= self.total_batch_size:
            self.current_batch_index = 0
            self._on_epoch_end()

        indices = self.file_indices[self.current_batch_index * self.batch_size:
                                    (self.current_batch_index + 1) * self.batch_size]
        annotation_file = [self.annotation_files[k] for k in indices]
        print(annotation_file)
        # annotation_file = ["../../data/car_data/Annotations/2011_001100.xml"]
        # annotation_file = ["../../data/voc_data/Annotations/2008_003374.xml"]
        img_file = [self.img_files[k] for k in indices]
        print(img_file)
        # img_file = ["../../data/car_data/JPEGImages/2011_001100.jpg"]
        # img_file = ["../../data/voc_data/JPEGImages/2008_003374.jpg"]
        imgs, gt_boxes = self._data_generation(annotation_files=annotation_file,
                                               img_files=img_file)
        self.current_batch_index += 1
        print(gt_boxes)
        return imgs, gt_boxes

    def _on_epoch_end(self):
        self.file_indices = np.arange(len(self.annotation_files))
        np.random.shuffle(self.file_indices)
        self.__balance_class_data()

    def _resize_im(self, im, box):
        """ 圖片統一處理到一樣的大小
        :param im:
        :param box:
        :return:
        """
        im_shape = im.shape
        im_size_max = np.max(im_shape[0:2])
        im_scale = float(self.im_size) / float(im_size_max)
        im_resize = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)

        im_resize_shape = im_resize.shape
        blob = np.zeros((self.im_size, self.im_size, 3), dtype=np.float32)
        blob[0:im_resize_shape[0], 0:im_resize_shape[1], :] = im_resize

        box[:, :4] = box[:, :4] * im_scale
        # print(im_shape, im_resize_shape, np.shape(blob))
        return blob, box, im_scale

    def _data_generation(self, annotation_files, img_files):
        """
        :param annotation_files:
        :param img_files:
        :return:
        """
        gt_boxes = []
        imgs = []
        for i in range(len(annotation_files)):
            img = cv2.imread(img_files[i])
            annotation = xml2dict(annotation_files[i])
            objs = annotation['annotation']['object']
            boxes = []
            if type(objs) == list:
                for box in objs:
                    cls_inds = self.cls_to_inds[box['name']]
                    xmin = int(float(box['bndbox']['xmin']))
                    ymin = int(float(box['bndbox']['ymin']))
                    xmax = int(float(box['bndbox']['xmax']))
                    ymax = int(float(box['bndbox']['ymax']))
                    boxes.append([xmin, ymin, xmax, ymax, cls_inds])
            else:
                cls_inds = self.cls_to_inds[objs['name']]
                xmin = int(float(objs['bndbox']['xmin']))
                ymin = int(float(objs['bndbox']['ymin']))
                xmax = int(float(objs['bndbox']['xmax']))
                ymax = int(float(objs['bndbox']['ymax']))
                boxes.append([xmin, ymin, xmax, ymax, cls_inds])

            if img is not None:
                img, boxes, _ = self._resize_im(img, np.array(boxes, dtype=np.float32))
                imgs.append(img)
                gt_boxes.append(boxes)
        return np.array(imgs, dtype=np.float32) - self.pixel_mean, np.array(gt_boxes, dtype=np.float32)

This class essentially reads the corresponding data out of the VOC dataset's folders; for a detailed description of the VOC dataset, refer to the face image dataset section of Tensorflow的圖像操作.
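
From the way the class indexes the parsed annotation, xml2dict is assumed to map the VOC XML one-to-one onto nested dicts, roughly like the following (hypothetical values; note that 'object' is a list when the image contains several objects but a single dict when it contains one, which is why the code branches on type(objs) == list):

annotation = {
    'annotation': {
        'size': {'width': '500', 'height': '375', 'depth': '3'},
        'object': [
            {'name': 'cat',
             'bndbox': {'xmin': '48', 'ymin': '240', 'xmax': '195', 'ymax': '371'}},
            {'name': 'dog',
             'bndbox': {'xmin': '8', 'ymin': '12', 'xmax': '352', 'ymax': '498'}},
        ],
    }
}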

Now let's look at the training loss function.

for epoch in range(100):
    for batch in range(train_data_generator.total_batch_size):
        print("epoch: {} batch: {}".format(epoch, batch))
        # Fetch this batch's training images and annotations
        train_imgs, train_gt_boxes = train_data_generator.next_batch()

        with tf.GradientTape() as tape:
            # The training images and annotations pass through the vgg16 backbone,
            # the RPN network and ROI Pooling, which return:
            # anchor_targets - labels, target boxes, foreground labels and fg/bg
            #                  weights restored to the full set of anchor boxes
            # proposal_targets - class labels, target boxes, foreground labels and
            #                    fg/bg weights of the post-NMS proposals
            # predictions - fg/bg classification of the 9 predicted boxes at every
            #               feature-map pixel
            # cls_prob - output of the classification fully connected layer
            # bbox_pred - output of the box-coordinate fully connected layer
            anchor_targets, proposal_targets, predictions, cls_prob, bbox_pred = \
                model([train_imgs, train_gt_boxes], training=True)
            # Define the loss function
            loss, cross_entropy, loss_box, rpn_cross_entropy, rpn_loss_box = \
                compute_losses(anchor_targets, proposal_targets, predictions)

Here let's focus on the Faster RCNN loss functions, starting with the RPN loss.

It contains a classification loss and a regression loss:

$$L(\{p_i\},\{t_i\}) = \frac{1}{N_{cls}}\sum_i L_{cls}(p_i, p_i^*) + \lambda\frac{1}{N_{reg}}\sum_i p_i^* L_{reg}(t_i, t_i^*)$$

Here N_cls is 256 (the total number of sampled foreground/background boxes) and L_cls is the cross-entropy loss; L_reg uses the smooth L1 method. Expanded, smooth L1 is:

$$\mathrm{smooth}_{L1}(x) = \begin{cases} 0.5x^2 & \text{if } |x| < 1 \\ |x| - 0.5 & \text{otherwise} \end{cases}$$

N_reg is the number of bounding boxes constrained to lie within the original image, taken here as 2400. To balance the classification and regression losses, the regression loss is multiplied by λ = 10 so that the two terms stay on a comparable scale and neither becomes too small for the network to learn anything from.
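
The balancing is easy to check numerically; with the values quoted above, the per-term weights of the two parts come out almost equal:

N_cls, N_reg, lam = 256, 2400, 10
print(1 / N_cls)    # 0.00390625  - weight on each classification term
print(lam / N_reg)  # ~0.00417    - weight on each regression term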

The Faster RCNN head loss has exactly the same form as the RPN loss; the only difference is that the classification cross-entropy is computed over the real classes (person, car, aeroplane, etc.) rather than over foreground/background.

Let's take a look at the compute_losses function.

def compute_losses(anchor_targets, proposal_targets, predictions):

    # The class-loss and bbox-loss below are the first-stage losses,
    # i.e. for the RPN network's predictions
    # RPN loss
    # Classification scores of all bounding boxes
    rpn_cls_score = tf.reshape(predictions['rpn_cls_score_reshape'], (-1, 2))
    # Classification labels of all bounding boxes
    rpn_label = tf.reshape(anchor_targets['rpn_labels'], (-1,))
    rpn_select = tf.where(rpn_label != -1)
    # Keep only the anchors whose label is not -1 (pure foreground or background);
    # only these contribute to the loss
    rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score, rpn_select), (-1, 2))
    rpn_label = tf.reshape(tf.gather(rpn_label, rpn_select), (-1,))
    # This modifies the original implementation: when the ground-truth objects are
    # small there are very few samples with rpn label = 1, so balance the 0/1 samples
    # Indices of the foreground anchors
    rpn_fg = tf.where(rpn_label == 1)
    # rpn_nums = tf.shape(rpn_fg)[0]
    rpn_nums = tf.cast(tf.shape(rpn_fg)[0], dtype=tf.float32) * 1.5
    rpn_nums = tf.cast(tf.math.floor(rpn_nums), dtype=tf.int32)
    # Foreground labels
    rpn_fg_label = tf.gather(rpn_label, rpn_fg)
    rpn_bg = tf.random.shuffle(tf.where(rpn_label == 0))[:rpn_nums]
    # Background labels
    rpn_bg_label = tf.gather(rpn_label, rpn_bg)
    rpn_idx = tf.concat([rpn_fg, rpn_bg], axis=0)
    rpn_label = tf.concat([rpn_fg_label, rpn_bg_label], axis=0)
    rpn_cls_score = tf.gather(rpn_cls_score, rpn_idx)

    rpn_cross_entropy = 0.
    if tf.shape(rpn_label)[0] > 0:
        # RPN classification loss
        rpn_cross_entropy = tf.reduce_mean(
            losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=rpn_label, y_pred=rpn_cls_score))

    # RPN regression loss
    # Predicted box offsets
    rpn_bbox_pred = predictions['rpn_bbox_pred']
    # Regression target boxes
    rpn_bbox_targets = anchor_targets['rpn_bbox_targets']
    # Positive-sample mask
    rpn_bbox_inside_weights = anchor_targets['rpn_bbox_inside_weights']
    # Positive/negative sample weights
    rpn_bbox_outside_weights = anchor_targets['rpn_bbox_outside_weights']
    # Compute the regression loss
    rpn_loss_box = smooth_l1_loss(bbox_pred=rpn_bbox_pred,
                                  bbox_targets=rpn_bbox_targets,
                                  bbox_inside_weights=rpn_bbox_inside_weights,
                                  bbox_outside_weights=rpn_bbox_outside_weights,
                                  sigma=3.0,
                                  dim=[1, 2, 3])

These are the two RPN losses (classification and regression). Let's take a look at the smooth_l1_loss function.

def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights, sigma, dim):
    """ Compute the bbox loss; see the Fast R-CNN paper for details
    https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Girshick_Fast_R-CNN_ICCV_2015_paper.pdf"""

    sigma_2 = sigma**2
    box_diff = bbox_pred - bbox_targets
    in_box_diff = bbox_inside_weights * box_diff
    abs_in_box_diff = tf.abs(in_box_diff)
    # With the sigma scaling used here:
    # smooth_l1(x) = 0.5 * (sigma * x)^2   if |x| < 1 / sigma^2
    # smooth_l1(x) = |x| - 0.5 / sigma^2   otherwise
    smoothL1_sign = tf.cast(tf.less(abs_in_box_diff, 1. / sigma_2), dtype=tf.float32)
    in_loss_box = tf.pow(in_box_diff, 2) * (sigma_2 / 2.) * smoothL1_sign \
                  + (abs_in_box_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign)
    out_loss_box = bbox_outside_weights * in_loss_box
    # Sum the loss within each batch element, then average across the batch
    loss_box = tf.reduce_mean(tf.reduce_sum(out_loss_box, axis=dim))
    return loss_box
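
A quick way to convince yourself that the two branches of the piecewise function join smoothly is to evaluate both at the switch point |x| = 1/σ²; a small check, plus a call to the function above with dummy tensors (the weights are set to 1 so nothing is masked out):

import tensorflow as tf

sigma = 3.0
x = 1. / sigma**2                # the branch boundary
print(0.5 * (sigma * x)**2)      # quadratic branch: ~0.0556
print(abs(x) - 0.5 / sigma**2)   # linear branch:    ~0.0556

pred = tf.constant([[0.3, -0.2, 0.05, 0.1]])
target = tf.zeros_like(pred)
ones = tf.ones_like(pred)
print(smooth_l1_loss(pred, target, ones, ones, sigma=sigma, dim=1))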

Next is the Faster RCNN (second-stage) loss.

# The class-loss and bbox-loss below are the second-stage losses, i.e. the output
# of the two fc layers after the RPN, which can be seen as the RCNN output
# RCNN loss
# Classification scores
cls_score = predictions["cls_score"]
# Class labels of the post-NMS proposals
label = tf.reshape(proposal_targets["labels"], [-1, ])
# As above, the original implementation is modified here to balance the samples
# when the objects are small
# Indices of the foreground class labels (real classes, not fg/bg)
fg = tf.where(label != 0)
nums = tf.cast(tf.shape(fg)[0], dtype=tf.float32) * 0.5
nums = tf.cast(tf.math.floor(nums), dtype=tf.int32)
fg_label = tf.gather(label, fg)
# Take half as many background indices as foreground ones
bg = tf.random.shuffle(tf.where(label == 0))[:nums]
bg_label = tf.gather(label, bg)
idx = tf.concat([fg, bg], axis=0)
label = tf.concat([fg_label, bg_label], axis=0)
# Scores over the 21 classes
cls_score = tf.gather(cls_score, idx)

cross_entropy = 0.
if tf.shape(label)[0] > 0:
    # RCNN classification loss
    cross_entropy = tf.reduce_mean(
        tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true=label, y_pred=cls_score))

# RCNN regression loss
# Box coordinates output by the fully connected layer
bbox_pred = predictions['bbox_pred']
# Regression target boxes of the post-NMS proposals
bbox_targets = proposal_targets['bbox_targets']
# Foreground labels of the post-NMS proposals
bbox_inside_weights = proposal_targets['bbox_inside_weights']
# Fg/bg weights of the post-NMS proposals
bbox_outside_weights = proposal_targets['bbox_outside_weights']
# Compute the RCNN regression loss
loss_box = smooth_l1_loss(bbox_pred=bbox_pred,
                          bbox_targets=bbox_targets,
                          bbox_inside_weights=bbox_inside_weights,
                          bbox_outside_weights=bbox_outside_weights,
                          sigma=3.0,
                          dim=1)

# Re-weight the individual losses here
cross_entropy = cross_entropy * 0.1
rpn_cross_entropy = rpn_cross_entropy * 0.5
# loss_box and rpn_loss_box keep their weight of 1.0
# Total Faster RCNN loss
loss = cross_entropy + loss_box + rpn_cross_entropy + rpn_loss_box

return loss, cross_entropy, loss_box, rpn_cross_entropy, rpn_loss_box

Back in the main method, we compute the gradients and apply the update.

if loss > 0 and cross_entropy > 0 and rpn_cross_entropy > 0:
    # Gradient update
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))
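
The summary_writer created earlier is the natural place to record these losses each step; a minimal sketch (the global_step counter is an assumption, the original code does not show one):

# Hypothetical step counter derived from the loop variables
global_step = epoch * train_data_generator.total_batch_size + batch
with summary_writer.as_default():
    tf.summary.scalar('loss/total', loss, step=global_step)
    tf.summary.scalar('loss/rpn_cls', rpn_cross_entropy, step=global_step)
    tf.summary.scalar('loss/rpn_box', rpn_loss_box, step=global_step)
    tf.summary.scalar('loss/rcnn_cls', cross_entropy, step=global_step)
    tf.summary.scalar('loss/rcnn_box', loss_box, step=global_step)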