Introduction to the YOLO Series

YOLOv1

YOLO (You Only Look Once: Unified, Real-Time Object Detection) was first proposed by Joseph Redmon, Ali Farhadi, et al. in 2015. At CVPR 2017, Redmon and Farhadi presented YOLOv2, and later followed with YOLOv3. YOLO is a canonical one-stage object detection algorithm.

Compared with the Faster RCNN and SSD families, YOLO more thoroughly follows the idea of directly regressing both the locations and the categories of the objects to be detected. Its core approach is to take an input image and simultaneously predict the positions and class labels of multiple bounding boxes, making it a more fully end-to-end method for object detection and recognition. It detects faster than Faster RCNN and SSD, but its overall detection accuracy is somewhat lower than either of those two families.

YOLO completes the entire detection process with a single CNN trained by direct regression, with no additional, specially designed stages. SSD relies on an anchor mechanism and makes predictions over several different feature maps; Faster RCNN first obtains candidate bounding boxes through an RPN and then produces the final results with a downstream prediction model. YOLO, by contrast, has neither an anchor mechanism nor a multi-scale design: a single convolutional network directly regresses the positions and classes of multiple bounding boxes. It is trained on whole images, which also helps it distinguish objects from background regions.

The core idea of YOLOv1 is to divide the image into an S*S grid and make each grid cell responsible for predicting the objects whose ground-truth (GT) centers fall inside it. In the figure above, the cell marked with a red dot contains part of the dog, so that cell is used to detect the dog; likewise, the cell containing the bicycle is used to predict the bicycle.

In practice, each cell predicts B bounding boxes (a hyperparameter, usually 2) together with a confidence score for each box (indicating the probability that the cell contains an object), and each cell additionally predicts C class probabilities (the likelihood of each category, shared by that cell's boxes). Each cell therefore outputs a vector of length 5*B + C, where the 5 covers each bounding box's (x, y, w, h, c) with c the confidence, and the vector regressed for the whole image has length (5*B + C)*S*S.

The bounding box information (x, y, w, h) gives the object's center offset relative to its grid cell plus its width and height, all normalized; in other words, the values we predict are relative values. The confidence reflects both whether the box contains an object and, if so, how accurate the predicted location is. It is defined as

$$ C = \Pr(\text{Object}) \times \text{IoU}^{\text{truth}}_{\text{pred}} $$

where Pr(Object) indicates whether an object falls in the cell, and IoU(truth, pred) is the intersection-over-union between the predicted box and the ground truth.

The YOLOv1 network architecture

YOLO's network directly regresses the positions and categories of the objects to be detected in the image. As the figure above shows, it consists only of a CNN: the raw input image passes through multiple convolutional layers and finally an FC layer, whose output vector has length S*S*(B*5 + C). For YOLOv1, S is typically 7, B is 2, and C is 20 (i.e., 20 classes). Through this regression we obtain, for each grid cell, the confidence of its bounding boxes and whether it contains an object to be detected; if it does, we get the box offsets, widths and heights, and the corresponding class probability distribution.
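As a quick check of the output length, plugging in the typical settings S = 7, B = 2, C = 20:

```python
# YOLOv1 output: each of the S*S grid cells predicts B boxes of
# (x, y, w, h, confidence) plus C shared class probabilities.
S, B, C = 7, 2, 20

per_cell = B * 5 + C      # 30 values per cell
total = S * S * per_cell  # length of the final FC output vector

print(per_cell, total)  # 30 1470
```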

Decoding the FC output vector yields

the two figures above. In the first figure, although each grid cell predicts several bounding boxes, only the box with the highest IoU is kept as the detection output. That is, each cell ultimately predicts only one object, which is in fact a weakness of YOLO: an image may contain several small objects, and when multiple objects fall into the same cell, only one of them can be predicted, so small-object detection degrades badly. This is the main defect of YOLOv1. After merging, screening, and filtering the bounding boxes with NMS, we obtain the final detection results.
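The NMS step mentioned above can be sketched as a minimal greedy procedure (an illustrative version, not YOLO's actual post-processing code; boxes are assumed to be in [x1, y1, x2, y2] format):

```python
import numpy as np

def iou(box, boxes):
    """IoU between one box and an array of boxes, all in [x1, y1, x2, y2]."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def nms(boxes, scores, iou_thresh=0.5):
    """Greedy NMS: repeatedly keep the highest-scoring box and drop overlaps."""
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        order = rest[iou(boxes[i], boxes[rest]) < iou_thresh]
    return keep

boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]], float)
scores = np.array([0.9, 0.8, 0.7])
print(nms(boxes, scores))  # [0, 2]: the overlapping lower-scoring box is dropped
```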

YOLO emphasizes small convolutions, i.e., 1*1 and 3*3 kernels (as in GoogLeNet), which both reduce computation and shrink the model. The network is faster than VGG16 but slightly less accurate.

  • YOLOv1 loss function

The loss contains three parts: coordinate error, IoU error, and classification error, corresponding to the three kinds of information each grid cell predicts. The coordinate error covers the (x, y, w, h) values of the B predicted bounding boxes; the IoU error covers each box's confidence score; and the classification error covers the C class probabilities of the cell, measured against its ground-truth (GT) object class.

These three errors are combined with weighting factors into the final loss, computed as a sum of squared errors, which is then used to train the network.
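For reference, the full loss from the YOLOv1 paper combines these terms as a weighted sum of squared errors; the square roots of w and h reduce the dominance of large boxes, and the weights λ_coord = 5, λ_noobj = 0.5 rebalance the coordinate and no-object terms:

$$
\begin{aligned}
\mathcal{L} ={}& \lambda_{coord}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\left[(x_i-\hat{x}_i)^2+(y_i-\hat{y}_i)^2\right] \\
&+ \lambda_{coord}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\left[\left(\sqrt{w_i}-\sqrt{\hat{w}_i}\right)^2+\left(\sqrt{h_i}-\sqrt{\hat{h}_i}\right)^2\right] \\
&+ \sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{obj}\left(C_i-\hat{C}_i\right)^2
 + \lambda_{noobj}\sum_{i=0}^{S^2}\sum_{j=0}^{B}\mathbb{1}_{ij}^{noobj}\left(C_i-\hat{C}_i\right)^2 \\
&+ \sum_{i=0}^{S^2}\mathbb{1}_{i}^{obj}\sum_{c \in classes}\left(p_i(c)-\hat{p}_i(c)\right)^2
\end{aligned}
$$

Here the indicator 1_ij^obj selects the box j of cell i responsible for an object, and 1_i^obj selects cells that contain an object.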

  • YOLOv1 network training

In practice, YOLOv1 uses the following techniques.

First, a pre-trained model is used to initialize the network parameters; here the model is pre-trained on the 1000-class ImageNet task. The last few layers differ between the classification and regression tasks: for 1000-class ImageNet classification the FC output size should be 1000, while YOLOv1's final FC output has length S*S*(B*5 + C), so when using the pre-trained model we discard its last FC layers.

Concretely, the first 20 convolutional layers of the pre-trained model are kept and used to initialize YOLO for the subsequent detection training, e.g., on the VOC 20-class dataset. ImageNet inputs are 224*224, whereas YOLOv1 resizes images to 448*448. Reusing only the convolutional layers makes this safe: a convolutional layer is insensitive to the feature-map size, caring only about its kernel parameters (size and channel count). Had we reused the FC layers, their parameter count would depend on the size of the input image or feature map, so changing the image size would change the FC input and the pre-trained FC weights could no longer be used. Because YOLO's pre-training keeps only the first 20 convolutional layers and drops the FC layers, the image size can be changed while the pre-trained weights remain usable.
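The parameter-count argument above can be checked directly in Keras (a small illustrative experiment; the layer sizes are made up, not from YOLO):

```python
import tensorflow as tf
from tensorflow.keras import layers

conv_params, fc_params = [], []
for size in (224, 448):
    inp = layers.Input([size, size, 3])

    conv = layers.Conv2D(32, 3)         # 3*3*3*32 weights + 32 biases = 896
    conv(inp)
    conv_params.append(conv.count_params())

    dense = layers.Dense(10)            # weights grow with the flattened input
    dense(layers.Flatten()(inp))
    fc_params.append(dense.count_params())

print(conv_params)  # [896, 896] -> independent of the image size
print(fc_params)    # [1505290, 6021130] -> tied to the image size
```

This is exactly why dropping the FC layers lets the 224*224 pre-trained weights be reused at 448*448.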

When training the B bounding boxes of a cell, they share the same ground truth (GT).

  • Problems with YOLOv1

YOLOv1 falls somewhat short of SSD and Faster RCNN for the following reasons.

  1. The final detection uses a fixed input size and no multi-scale features. SSD, by contrast, extracts prior boxes and makes predictions over 6 scales, while YOLO is a single plain convolutional network without multi-scale feature maps. The object features learned after multiple downsampling layers are therefore not fine-grained, which hurts detection quality.
  2. YOLOv1 performs poorly on small objects. When one cell contains several objects, only the one with the highest IoU is predicted and the others are ignored, so missed detections are inevitable.
  3. The IoU term of the YOLOv1 loss does not distinguish between the errors of large and small objects: their contributions to the training loss are roughly equal, even though an IoU error on a small object should matter more for optimization, and ignoring this lowers localization accuracy. The loss design of YOLOv1 is thus rather plain, and it became an improvement point for later YOLO versions.
  4. When an object appears with a new, unusual aspect ratio or in other uncommon configurations, YOLOv1 also generalizes poorly.
  • YOLOv1 network performance

The figure above compares the accuracy of training at different scales against other networks; on the same datasets, YOLOv1's mAP (accuracy) drops considerably. In detection speed, however, at the same scale (448*448) YOLO reaches 45 FPS, far faster than Faster RCNN, much faster than SSD500 (500*500 input), and close to SSD300 (300*300 input). In other words, YOLOv1 detects larger images at about the same speed as SSD detects smaller ones.

YOLOv2

To address the problems of YOLOv1, the authors proposed YOLOv2 in 2017 and, on top of it, the YOLO9000 model. YOLOv2's core improvements over YOLOv1 are:

  1. It introduces anchor boxes, refining the rather coarse direct-regression approach.
  2. The output layer replaces YOLOv1's fully connected (FC) layer with a convolutional layer. The direct benefit is reduced sensitivity to the input image size: FC parameter counts are tied to the image size, whereas convolutional layers are not.
  3. YOLO9000 is ultimately trained jointly on ImageNet classification data and the COCO detection dataset: the detection data teach accurate object localization, while the classification data teach category information. This multi-task setup improves the robustness of the final network.
  4. Compared with YOLOv1, YOLOv2 improves greatly not only in the variety of objects it can recognize but also in accuracy, speed, and localization.

YOLOv2 became one of the most representative detection algorithms of its time. The improvements in YOLOv2/YOLO9000:

In the figure above, the backbone is a DarkNet architecture. In YOLOv1 the authors used a GoogLeNet-style backbone, which outperforms VGGNet. DarkNet resembles VGGNet in using small 3*3 kernels and doubling the channel count after each pooling step, and it applies Batch Normalization throughout, making training more stable, speeding up convergence, and regularizing the model.

Because convolutional layers replace the FC layer, the input image size can vary: the network's parameters are independent of the feature-map size, so we can change the image size for multi-scale training. For the classification model, a high-resolution classifier is used.

YOLOv1 uses features at only one level, so the features it learns are relatively coarse. YOLOv2 adds a skip connection: at prediction time it uses features of different granularities and fuses them to improve detection. It also adopts the anchor mechanism, a core element of Faster RCNN and SSD that brings a gain in model performance.

  • Batch Normalization
  1. V1 already used BN extensively, but its FC localization layers used dropout against overfitting.
  2. V2 drops dropout and applies BN throughout the network, regularizing the model, making it more stable, and speeding up convergence.
  • High-resolution classifier
  1. V1 pre-trains at 224*224 but detects on 448*448 images, a mismatch that necessarily introduces a shift in distribution.
  2. V2 fine-tunes the initial classification network directly at 448*448, keeping the classification and detection models consistent in distribution.
  • Anchor Boxes
  1. Bounding-box offsets are predicted with convolutions instead of FC layers. In V1 the FC output vector has size S*S*(B*5 + C); in V2 a convolution produces an S*S feature map whose (B*5 + C) channels carry the same information, achieving the same effect as V1's FC layer.
  2. V2 uses a 416*416 input rather than 448*448. Objects, especially large ones, tend to appear at the center of the image, so a single cell located at the center should be responsible for predicting them. YOLO downsamples by a factor of 32 overall: a 416*416 image yields a 13*13 feature map, while a 448*448 image yields 14*14, which has no single cell at the exact center. To guarantee a central cell, the downsampled feature map is fixed at 13*13 (odd width and height); multiplying back by the 32-pixel stride, 13*32 = 416 gives the input size. If the width and height were even, the four cells around the center would all be needed to predict a central object; making them odd yields a single center cell and improves overall efficiency.
  3. With anchors, each cell predicts multiple proposal boxes instead of just B (usually 2) as before. Recall improves markedly while mAP drops slightly; in the authors' view, the small accuracy drop is well worth the large recall gain, showing that anchor boxes do improve overall model quality, though the accuracy drop still needs further optimization. V2 uses max pooling for downsampling.
  4. With anchors, the number of predicted bounding boxes exceeds 1000: for a 13*13 feature map with 9 anchors per cell, 13*13*9 = 1521 boxes are predicted, versus 7*7*2 = 98 before, and the larger pool of candidates helps performance. The authors hit two problems with anchor boxes, however. First, anchor widths and heights are hand-picked priors; although the network adjusts box dimensions during training to reach accurate positions, starting from the most representative prior dimensions makes accurate predictions easier to learn. The authors therefore ran K-means over the ground-truth bounding boxes to automatically find good width/height priors. Second, standard K-means measures box similarity with Euclidean distance, which makes large boxes contribute more error than small ones, so the authors used the IoU score as the distance metric instead, making the error independent of box size. The clustering showed that 5 anchor boxes is a good trade-off, so 5 box dimensions were chosen for localization. The clusters also contained fewer short, wide boxes and more tall, thin ones, which matches the shape of pedestrians. For more on K-means, see the clustering article.
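The IoU-based clustering can be sketched as follows: boxes are reduced to (width, height) pairs anchored at a common center, and the K-means distance is 1 - IoU. This is an illustrative toy version; the box data and the fixed initialization are made up, not from the paper:

```python
import numpy as np

def wh_iou(wh, centroids):
    """IoU between one (w, h) box and each centroid, all sharing one center."""
    inter = np.minimum(wh[0], centroids[:, 0]) * np.minimum(wh[1], centroids[:, 1])
    union = wh[0] * wh[1] + centroids[:, 0] * centroids[:, 1] - inter
    return inter / union

def kmeans_anchors(whs, centroids, iters=20):
    """K-means on box shapes using d = 1 - IoU instead of Euclidean distance."""
    centroids = centroids.copy()
    for _ in range(iters):
        # assign each box to the centroid with the largest IoU (smallest distance)
        assign = np.array([np.argmax(wh_iou(wh, centroids)) for wh in whs])
        # move each centroid to the mean shape of its cluster
        for j in range(len(centroids)):
            if np.any(assign == j):
                centroids[j] = whs[assign == j].mean(axis=0)
    return centroids

# toy box shapes: a wide cluster, a tall cluster, and a large square cluster
whs = np.array([[10, 5], [11, 5], [5, 10], [5, 11], [30, 30], [29, 31]], float)
anchors = kmeans_anchors(whs, centroids=whs[[0, 2, 4]])
print(anchors)  # cluster means: [10.5, 5], [5, 10.5], [29.5, 30.5]
```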
  • Fine-grained features

V1 used a straight top-down network without features at different scales. V2 adds a passthrough layer that connects a shallow feature map (26*26) to the deep one (13*13). Rather than pooling, the 26*26*512 feature map is reorganized directly into 13*13*2048 and then connected to the deep feature map, adding fine-grained features. Fusing coarse- and fine-grained features gave a 1% performance gain. This is similar in spirit to the identity mapping in ResNet.
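The passthrough reorganization is a space-to-depth transform: each 2*2 spatial block of the 26*26*512 map is stacked into the channel dimension, giving 13*13*2048. A TensorFlow sketch (the deep map's 1024 channels are illustrative):

```python
import tensorflow as tf

shallow = tf.zeros([1, 26, 26, 512])          # dummy 26x26x512 shallow feature map

# stack each 2x2 spatial block into channels: (26, 26, 512) -> (13, 13, 2048)
reorg = tf.nn.space_to_depth(shallow, block_size=2)
print(reorg.shape)                            # (1, 13, 13, 2048)

deep = tf.zeros([1, 13, 13, 1024])            # dummy deep feature map
fused = tf.concat([reorg, deep], axis=-1)     # fuse fine + coarse features
print(fused.shape)                            # (1, 13, 13, 3072)
```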

  • Multi-Scale Training

Multi-scale training: every few iterations the network's input size is adjusted. Input sizes span several scales {320, 352, ..., 608}. The same parameters work for all input scales because V2 contains no FC layers tied to the feature-map size: the network is built entirely from stacked convolutions, so its parameter count is independent of feature-map size. Varying the image size therefore makes the network robust to scale changes, letting one network predict images at different scales and handle detection tasks at different resolutions. On small images V2 runs even faster, striking a balance between speed and accuracy: with 288*288 inputs its FPS reaches 90 with mAP on par with Faster RCNN. V2 is therefore well suited to low-end GPUs, high-frame-rate video, and multi-stream video scenarios; in low-power and video processing settings, YOLO has a wider range of applications, since it achieves higher real-time speed while keeping accuracy at the same level as other deep learning detectors.
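The scale set {320, 352, ..., 608} consists of multiples of the network's 32-pixel stride, and a new size is drawn every 10 batches. A sketch (the batch loop is illustrative):

```python
import random

stride = 32
sizes = list(range(320, 608 + 1, stride))
print(sizes)  # [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]

random.seed(0)
for batch in range(30):
    if batch % 10 == 0:                 # resize the network every 10 batches
        input_size = random.choice(sizes)
        print(f"batch {batch}: input {input_size}x{input_size}")
```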

  • Darknet-19

V1 used a GoogLeNet-style backbone; for V2 the authors designed a new feature extractor, Darknet-19, drawing on prior work. Like VGGNet, it stacks many 3*3 kernels and doubles the channel count after each max-pooling step. Following Network in Network, it uses global average pooling and places 1*1 kernels between the 3*3 kernels to compress features, and batch normalization is applied throughout, both speeding up training and making it more stable. Darknet-19 has 19 convolutional layers and 5 pooling layers and needs 5.58 billion operations, somewhat fewer than VGGNet, while reaching 72.9% top-1 accuracy on ImageNet classification. Of course, YOLOv2 can also be paired with more advanced backbones such as ResNet or DenseNet, or lighter ones such as MobileNet; which backbone works best can be found by experimenting and comparing their effect on YOLOv2's performance.

When YOLOv2 predicts with anchor boxes, each bounding box gets 4 coordinate values, 1 confidence value, and a probability distribution over C classes. This differs from V1, where the class values belong to a cell: each cell has B bounding boxes and outputs a vector of length (5*B + C). In V2 the class distribution belongs to each bounding box, consistent with the anchor mechanism: each box predicts a vector of length (5 + C), so a cell with B anchor boxes predicts B*(5 + C) values in total. This is a key difference from V1: classification is attached to each bounding box, precisely because anchors are used, whereas YOLOv1's classes are attached to each cell, i.e., each cell predicts objects of only one class.
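The two output layouts can be compared numerically; the settings S = 13, B = 5 anchors, and C = 20 VOC classes below are for illustration:

```python
S, B, C = 13, 5, 20

v1_per_cell = 5 * B + C        # V1 layout: classes shared across the cell's boxes
v2_per_cell = B * (5 + C)      # V2 layout: each anchor box has its own classes

print(v1_per_cell)             # 45
print(v2_per_cell)             # 125
print(S * S * v2_per_cell)     # 21125 values in the 13x13x125 output tensor
```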

  • YOLOv2 network performance

Looking at YOLOv2's improvements over YOLOv1 one by one, almost every addition brings a performance gain. The one dip, shown in the figure above, comes from adding anchor boxes: mAP drops slightly from 69.5 to 69.2, but in exchange recall rises substantially. After the later additions of the skip connection and multi-scale training, YOLOv2 improves greatly over V1 overall, from 63.4 to 78.6.

Compared with SSD and Faster RCNN, YOLOv2 reaches better detection accuracy at a faster detection speed, making it the state-of-the-art deep learning detector of its time.

Likewise, the figure above plots mAP against FPS overall: YOLOv2 achieves a better trade-off, keeping good accuracy while maintaining fast detection.

YOLO9000

YOLO9000, built on YOLOv2, is a model that can detect more than 9000 categories; its main contribution is a joint training strategy for classification and detection.

This rests on the WordTree structure, which mixes detection-dataset and recognition-dataset samples so detection and classification can be trained jointly, here on ImageNet and COCO. Classification labels are finer-grained: for a dog, a classification dataset may include finer breed labels such as Husky or Golden Retriever, whereas a detection dataset only distinguishes coarse concepts like cat and dog. Naively merging classification and regression data would leave labels such as "dog" and "Husky" coexisting with no relation between them. WordTree instead builds the granularity relations between these labels and fuses the classification and detection datasets. On detection data the model must both regress box locations and decide the object's category; on classification data it only classifies, but at a finer granularity. WordTree expresses the hierarchy among labels: it is represented as a graph built from WordNet, from which the relations and containment between labels can be read off.

In training, an image's label is expanded to include its ancestors: a dog is also a mammal, a canine, and perhaps a domestic animal, and all of these become labels of the image. One image thus receives several labels, and the labels need not be mutually exclusive. Plain ImageNet classification uses one large SoftMax over all classes; WordTree instead applies a SoftMax over the co-hyponyms under each concept. The benefit is that performance degrades gracefully on new, unknown objects: shown a dog of unknown breed, the model still predicts "dog" with high confidence, while the breed labels beneath it, such as Husky or Golden Retriever, get low confidence. In this way the authors mixed the COCO detection set with ImageNet classification data and trained detection and classification on the mixture, obtaining YOLO9000, a stronger classifier and detector that handles detection and classification of 9000 categories while staying highly real-time. YOLO9000 is therefore called the "stronger" version of YOLOv2.
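The group-wise SoftMax can be sketched on a tiny, made-up WordTree: a SoftMax is applied within each sibling group, and an absolute label probability is the product of conditionals along the path to the root. Every logit below is illustrative:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

# sibling groups of a toy WordTree (all logits are made up)
p_animal, p_vehicle = softmax(np.array([2.0, 0.0]))   # children of the root
p_dog, p_cat = softmax(np.array([1.5, 0.5]))          # children of "animal"
p_husky, p_golden = softmax(np.array([0.1, 0.0]))     # children of "dog"

# absolute probabilities multiply the conditionals along the path
p_dog_abs = p_animal * p_dog
p_husky_abs = p_dog_abs * p_husky

# an unfamiliar dog: "dog" stays confident while breed probabilities stay low
print(round(float(p_dog_abs), 3), round(float(p_husky_abs), 3))
```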

In the figure above, ImageNet classification handles every class with one large SoftMax, whereas WordTree's SoftMax must account for the relations between labels: it is computed over the co-hyponyms under each concept, and the final classification loss is assembled from these group-wise SoftMaxes. With the joint training strategy, YOLO9000 can quickly detect objects from over 9000 categories, with an overall mAP of 19.7%.

YOLOv3

Compared with V1 and V2, YOLOv3 focuses more on balancing speed and accuracy; it incorporates more recent techniques and in particular addresses small-object detection.

  • YOLOv3 improvement strategies:

1. The backbone is optimized first, adopting a ResNet-like structure to extract more and better feature representations.

As the figure above shows, the ResNet-style backbone achieves better detection results, though a deeper network does cost detection speed; this is again a trade-off between speed and accuracy.

2. Multi-scale prediction with an FPN-like structure improves detection accuracy.

In the lower right of the figure above, V3 extracts features from feature maps at several scales as inputs to the YOLO detection heads. The anchors are again obtained by clustering: 9 clusters (centroids) are computed and split evenly across 3 scales, with each scale predicting 3 bounding boxes. At each scale, a few convolutional layers further process the features before the box information is output. At scale 1 the boxes are output directly after these convolutions. At scale 2, before outputting boxes, the convolutional output of scale 1 is upsampled and merged with the scale-2 feature map, and the result feeds the subsequent box outputs; this feature map is twice the size of scale 1's. Scale 3 likewise doubles relative to scale 2: its input is the upsampled scale-2 features combined with the original feature map at that resolution, followed by convolutions that output the final boxes. The whole structure resembles an FPN.
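For a 416*416 input (the common default), the three scales above have overall strides of 32, 16, and 8, and with 3 anchors each the total number of predicted boxes can be tallied quickly:

```python
input_size = 416
strides = [32, 16, 8]     # scale 1 -> 13x13, scale 2 -> 26x26, scale 3 -> 52x52
anchors_per_scale = 3     # the 9 clustered anchors, split evenly over 3 scales

total = 0
for s in strides:
    grid = input_size // s
    boxes = grid * grid * anchors_per_scale
    print(f"stride {s:2d}: {grid}x{grid} grid -> {boxes} boxes")
    total += boxes

print(total)  # 10647 predicted boxes in all
```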

3. A better classifier (binary cross-entropy loss) handles the classification task.

The reason is that when SoftMax classifies a bounding box it can assign only one class: it outputs a probability distribution and takes the highest-probability class as the box's label. When targets carry overlapping labels, SoftMax is unsuitable for such multi-label classification. In fact, SoftMax can be replaced by multiple independent logistic classifiers without losing accuracy.
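The difference shows up numerically: SoftMax forces class scores to compete, while independent logistic (sigmoid) outputs can both be high for overlapping labels. A minimal NumPy sketch with made-up scores:

```python
import numpy as np

logits = np.array([2.0, 1.9, -1.0])   # made-up scores for [person, woman, car]

softmax = np.exp(logits) / np.exp(logits).sum()
sigmoid = 1.0 / (1.0 + np.exp(-logits))

print(np.round(softmax, 2))  # sums to 1: the two overlapping labels must compete
print(np.round(sigmoid, 2))  # independent scores: both labels can exceed 0.5
```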

  • YOLOv3 network performance

As the figure above shows, V3 achieves better results than the YOLOv2 architecture; since the V3 entry in the figure uses Darknet, its numbers are somewhat below the variants built on ResNet.

Among different backbones for YOLOv3 itself, ResNet-152 achieves the best overall performance.

A performance comparison on COCO is also given: against other object detection networks, YOLOv3 likewise holds a clear performance advantage. Its overall detection speed drops somewhat, but it remains faster than the other detection algorithms.

The reference YOLOv3 implementation is built on the Darknet framework, which is written in C and CUDA, uses GPU memory efficiently, and has few third-party dependencies. It ports easily across platforms and works well on Windows and embedded devices; Darknet is a good framework for implementing deep networks.

Now let's look at the YOLOv3 code structure, again with Darknet as the V3 backbone.

import tensorflow as tf
from tensorflow.keras import layers, regularizers, models


class DarkNet:
    def __init__(self):
        pass

    def _darknet_conv(self, x, filters, size, strides=1, batch_norm=True):
        if strides == 1:
            padding = 'same'
        else:
            # pad one row of zeros on top and one column on the left
            x = layers.ZeroPadding2D(((1, 0), (1, 0)))(x)  # top-left half padding
            padding = 'valid'
        x = layers.Conv2D(filters, (size, size),
                          strides=strides,
                          padding=padding,
                          use_bias=not batch_norm,
                          kernel_regularizer=regularizers.l2(0.0005))(x)
        if batch_norm:
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(alpha=0.1)(x)
        return x

    def _darknet_residual(self, x, filters):
        prev = x
        x = self._darknet_conv(x, filters // 2, 1)
        x = self._darknet_conv(x, filters, 3)
        x = layers.Add()([prev, x])
        return x

    def _darknet_block(self, x, filters, blocks):
        x = self._darknet_conv(x, filters, 3, strides=2)
        for _ in range(blocks):
            x = self._darknet_residual(x, filters)
        return x

    def build_darknet(self, x, name=None):
        # x = inputs = tf.keras.layers.Input([None, None, 3])
        x = self._darknet_conv(x, 32, 3)
        # 1/2
        x = self._darknet_block(x, 64, 1)
        # 1/4
        x = self._darknet_block(x, 128, 2)
        # 1/8
        x = x1 = self._darknet_block(x, 256, 8)
        # 1/16
        x = x2 = self._darknet_block(x, 512, 8)
        # 1/32
        x3 = self._darknet_block(x, 1024, 4)
        # return tf.keras.Model(inputs, (x_36, x_61, x), name=name)
        return x1, x2, x3

    def build_darknet_tiny(self, x, name=None):
        # x = inputs = tf.keras.layers.Input([None, None, 3])
        x = self._darknet_conv(x, 16, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 32, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 64, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 128, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = x_8 = self._darknet_conv(x, 256, 3)  # skip connection
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 512, 3)
        x = layers.MaxPool2D(2, 1, 'same')(x)
        x = self._darknet_conv(x, 1024, 3)
        # return tf.keras.Model(inputs, (x_8, x), name=name)
        return x_8, x

if __name__ == '__main__':
    # import os
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    darknet = DarkNet()
    x = layers.Input(shape=(500, 600, 3))
    darknet_model = darknet.build_darknet(x)
    model = models.Model(x, darknet_model)
    model.summary()

Output

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 500, 600, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 500, 600, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 500, 600, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 500, 600, 32) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 501, 601, 32) 0           leaky_re_lu[0][0]                
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 250, 300, 64) 18432       zero_padding2d[0][0]             
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 250, 300, 64) 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 250, 300, 64) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 250, 300, 32) 2048        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 250, 300, 32) 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 250, 300, 32) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 250, 300, 64) 18432       leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 250, 300, 64) 256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 250, 300, 64) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
add (Add)                       (None, 250, 300, 64) 0           leaky_re_lu_1[0][0]              
                                                                 leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 251, 301, 64) 0           add[0][0]                        
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 125, 150, 128 73728       zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 125, 150, 128 512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 125, 150, 128 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 125, 150, 64) 8192        leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 125, 150, 64) 256         conv2d_5[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 125, 150, 64) 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 125, 150, 128 73728       leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 125, 150, 128 512         conv2d_6[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 125, 150, 128 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
add_1 (Add)                     (None, 125, 150, 128 0           leaky_re_lu_4[0][0]              
                                                                 leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 125, 150, 64) 8192        add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 125, 150, 64) 256         conv2d_7[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 125, 150, 64) 0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 125, 150, 128 73728       leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 125, 150, 128 512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 125, 150, 128 0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 125, 150, 128 0           add_1[0][0]                      
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 126, 151, 128 0           add_2[0][0]                      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 62, 75, 256)  294912      zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 62, 75, 256)  1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 62, 75, 256)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 62, 75, 128)  32768       leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 62, 75, 128)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 62, 75, 256)  1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
add_3 (Add)                     (None, 62, 75, 256)  0           leaky_re_lu_9[0][0]              
                                                                 leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 62, 75, 128)  32768       add_3[0][0]                      
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 62, 75, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_12[0][0]             
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 62, 75, 256)  1024        conv2d_13[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
add_4 (Add)                     (None, 62, 75, 256)  0           add_3[0][0]                      
                                                                 leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 62, 75, 128)  32768       add_4[0][0]                      
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 62, 75, 128)  512         conv2d_14[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 62, 75, 256)  1024        conv2d_15[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
add_5 (Add)                     (None, 62, 75, 256)  0           add_4[0][0]                      
                                                                 leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 62, 75, 128)  32768       add_5[0][0]                      
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 62, 75, 128)  512         conv2d_16[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_16[0][0]             
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 62, 75, 256)  1024        conv2d_17[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
add_6 (Add)                     (None, 62, 75, 256)  0           add_5[0][0]                      
                                                                 leaky_re_lu_17[0][0]             
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 62, 75, 128)  32768       add_6[0][0]                      
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 62, 75, 128)  512         conv2d_18[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_18 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 62, 75, 256)  1024        conv2d_19[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_19 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
add_7 (Add)                     (None, 62, 75, 256)  0           add_6[0][0]                      
                                                                 leaky_re_lu_19[0][0]             
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 62, 75, 128)  32768       add_7[0][0]                      
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 62, 75, 128)  512         conv2d_20[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_20 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_20[0][0]     
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_20[0][0]             
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 62, 75, 256)  1024        conv2d_21[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_21 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_21[0][0]     
__________________________________________________________________________________________________
add_8 (Add)                     (None, 62, 75, 256)  0           add_7[0][0]                      
                                                                 leaky_re_lu_21[0][0]             
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 62, 75, 128)  32768       add_8[0][0]                      
__________________________________________________________________________________________________
batch_normalization_22 (BatchNo (None, 62, 75, 128)  512         conv2d_22[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_22[0][0]     
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
batch_normalization_23 (BatchNo (None, 62, 75, 256)  1024        conv2d_23[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_23[0][0]     
__________________________________________________________________________________________________
add_9 (Add)                     (None, 62, 75, 256)  0           add_8[0][0]                      
                                                                 leaky_re_lu_23[0][0]             
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 62, 75, 128)  32768       add_9[0][0]                      
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, 62, 75, 128)  512         conv2d_24[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_24 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_24[0][0]     
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_24[0][0]             
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 62, 75, 256)  1024        conv2d_25[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_25 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
add_10 (Add)                    (None, 62, 75, 256)  0           add_9[0][0]                      
                                                                 leaky_re_lu_25[0][0]             
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 63, 76, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 31, 37, 512)  1179648     zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 31, 37, 512)  2048        conv2d_26[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_26 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_26[0][0]     
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 31, 37, 256)  131072      leaky_re_lu_26[0][0]             
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 31, 37, 256)  1024        conv2d_27[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_27 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_27[0][0]     
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_27[0][0]             
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 31, 37, 512)  2048        conv2d_28[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_28 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
add_11 (Add)                    (None, 31, 37, 512)  0           leaky_re_lu_26[0][0]             
                                                                 leaky_re_lu_28[0][0]             
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 31, 37, 256)  131072      add_11[0][0]                     
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 31, 37, 256)  1024        conv2d_29[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_29 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_29[0][0]     
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_29[0][0]             
__________________________________________________________________________________________________
batch_normalization_30 (BatchNo (None, 31, 37, 512)  2048        conv2d_30[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_30 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_30[0][0]     
__________________________________________________________________________________________________
add_12 (Add)                    (None, 31, 37, 512)  0           add_11[0][0]                     
                                                                 leaky_re_lu_30[0][0]             
__________________________________________________________________________________________________
conv2d_31 (Conv2D)              (None, 31, 37, 256)  131072      add_12[0][0]                     
__________________________________________________________________________________________________
batch_normalization_31 (BatchNo (None, 31, 37, 256)  1024        conv2d_31[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_31 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_31[0][0]     
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_31[0][0]             
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 31, 37, 512)  2048        conv2d_32[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_32[0][0]     
__________________________________________________________________________________________________
add_13 (Add)                    (None, 31, 37, 512)  0           add_12[0][0]                     
                                                                 leaky_re_lu_32[0][0]             
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 31, 37, 256)  131072      add_13[0][0]                     
__________________________________________________________________________________________________
batch_normalization_33 (BatchNo (None, 31, 37, 256)  1024        conv2d_33[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_33 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_33[0][0]     
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_33[0][0]             
__________________________________________________________________________________________________
batch_normalization_34 (BatchNo (None, 31, 37, 512)  2048        conv2d_34[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_34 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_34[0][0]     
__________________________________________________________________________________________________
add_14 (Add)                    (None, 31, 37, 512)  0           add_13[0][0]                     
                                                                 leaky_re_lu_34[0][0]             
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 31, 37, 256)  131072      add_14[0][0]                     
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 31, 37, 256)  1024        conv2d_35[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_35 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_35[0][0]     
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_35[0][0]             
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 31, 37, 512)  2048        conv2d_36[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_36 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_36[0][0]     
__________________________________________________________________________________________________
add_15 (Add)                    (None, 31, 37, 512)  0           add_14[0][0]                     
                                                                 leaky_re_lu_36[0][0]             
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 31, 37, 256)  131072      add_15[0][0]                     
__________________________________________________________________________________________________
batch_normalization_37 (BatchNo (None, 31, 37, 256)  1024        conv2d_37[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_37 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_37[0][0]     
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_37[0][0]             
__________________________________________________________________________________________________
batch_normalization_38 (BatchNo (None, 31, 37, 512)  2048        conv2d_38[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_38 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_38[0][0]     
__________________________________________________________________________________________________
add_16 (Add)                    (None, 31, 37, 512)  0           add_15[0][0]                     
                                                                 leaky_re_lu_38[0][0]             
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 31, 37, 256)  131072      add_16[0][0]                     
__________________________________________________________________________________________________
batch_normalization_39 (BatchNo (None, 31, 37, 256)  1024        conv2d_39[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_39 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_39[0][0]     
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_39[0][0]             
__________________________________________________________________________________________________
batch_normalization_40 (BatchNo (None, 31, 37, 512)  2048        conv2d_40[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_40 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_40[0][0]     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 31, 37, 512)  0           add_16[0][0]                     
                                                                 leaky_re_lu_40[0][0]             
__________________________________________________________________________________________________
conv2d_41 (Conv2D)              (None, 31, 37, 256)  131072      add_17[0][0]                     
__________________________________________________________________________________________________
batch_normalization_41 (BatchNo (None, 31, 37, 256)  1024        conv2d_41[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_41 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_41[0][0]     
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_41[0][0]             
__________________________________________________________________________________________________
batch_normalization_42 (BatchNo (None, 31, 37, 512)  2048        conv2d_42[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_42 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_42[0][0]     
__________________________________________________________________________________________________
add_18 (Add)                    (None, 31, 37, 512)  0           add_17[0][0]                     
                                                                 leaky_re_lu_42[0][0]             
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 32, 38, 512)  0           add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_43 (Conv2D)              (None, 15, 18, 1024) 4718592     zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
batch_normalization_43 (BatchNo (None, 15, 18, 1024) 4096        conv2d_43[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_43 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_43[0][0]     
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 15, 18, 512)  524288      leaky_re_lu_43[0][0]             
__________________________________________________________________________________________________
batch_normalization_44 (BatchNo (None, 15, 18, 512)  2048        conv2d_44[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_44 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_44[0][0]     
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_44[0][0]             
__________________________________________________________________________________________________
batch_normalization_45 (BatchNo (None, 15, 18, 1024) 4096        conv2d_45[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_45 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_45[0][0]     
__________________________________________________________________________________________________
add_19 (Add)                    (None, 15, 18, 1024) 0           leaky_re_lu_43[0][0]             
                                                                 leaky_re_lu_45[0][0]             
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 15, 18, 512)  524288      add_19[0][0]                     
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 15, 18, 512)  2048        conv2d_46[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_46[0][0]     
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_46[0][0]             
__________________________________________________________________________________________________
batch_normalization_47 (BatchNo (None, 15, 18, 1024) 4096        conv2d_47[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_47 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_47[0][0]     
__________________________________________________________________________________________________
add_20 (Add)                    (None, 15, 18, 1024) 0           add_19[0][0]                     
                                                                 leaky_re_lu_47[0][0]             
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 15, 18, 512)  524288      add_20[0][0]                     
__________________________________________________________________________________________________
batch_normalization_48 (BatchNo (None, 15, 18, 512)  2048        conv2d_48[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_48 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_48[0][0]     
__________________________________________________________________________________________________
conv2d_49 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_48[0][0]             
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 15, 18, 1024) 4096        conv2d_49[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_49 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_49[0][0]     
__________________________________________________________________________________________________
add_21 (Add)                    (None, 15, 18, 1024) 0           add_20[0][0]                     
                                                                 leaky_re_lu_49[0][0]             
__________________________________________________________________________________________________
conv2d_50 (Conv2D)              (None, 15, 18, 512)  524288      add_21[0][0]                     
__________________________________________________________________________________________________
batch_normalization_50 (BatchNo (None, 15, 18, 512)  2048        conv2d_50[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_50 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_50[0][0]     
__________________________________________________________________________________________________
conv2d_51 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_50[0][0]             
__________________________________________________________________________________________________
batch_normalization_51 (BatchNo (None, 15, 18, 1024) 4096        conv2d_51[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_51 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_51[0][0]     
__________________________________________________________________________________________________
add_22 (Add)                    (None, 15, 18, 1024) 0           add_21[0][0]                     
                                                                 leaky_re_lu_51[0][0]             
==================================================================================================
Total params: 40,620,640
Trainable params: 40,584,928
Non-trainable params: 35,712
__________________________________________________________________________________________________
None
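
The repeating pattern visible in the summary above (a 1*1 Conv2D that halves the channels, a 3*3 Conv2D that restores them, each followed by BatchNormalization and LeakyReLU, then an Add back to the block input) is the DarkNet residual block. A minimal Keras sketch of one such block follows; this is an illustrative reconstruction from the summary, not the project's actual `DarkNet` class:

```python
import tensorflow as tf
from tensorflow.keras import layers

def darknet_residual_block(x, filters):
    """One DarkNet residual block: 1x1 bottleneck -> 3x3 conv -> skip add.

    Mirrors the conv2d -> batch_normalization -> leaky_re_lu -> add
    pattern in the model summary above.
    """
    skip = x
    # 1x1 conv halves the channel count (e.g. 256 -> 128)
    x = layers.Conv2D(filters // 2, 1, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # 3x3 conv restores the channel count (128 -> 256)
    x = layers.Conv2D(filters, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # residual connection: spatial size and channel count are unchanged
    return layers.Add()([skip, x])

inputs = layers.Input(shape=(62, 75, 256))
outputs = darknet_residual_block(inputs, 256)
model = tf.keras.Model(inputs, outputs)
print(model.output_shape)  # (None, 62, 75, 256)
```

With `filters=256`, the two convolutions contribute 32768 and 294912 parameters, which matches the `conv2d_22`/`conv2d_23` rows of the summary.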

 Now let's feed a real image into the DarkNet network in the same way and look at its output.

import tensorflow as tf
from tensorflow.keras import layers, preprocessing, backend, models, optimizers, losses
import numpy as np
from skimage.transform import resize
import cv2
from darknet import DarkNet
from data.generate_coco_data import CoCoDataGenrator

if __name__ == '__main__':

    # Load the image as RGB and resize it to the 640*640 network input size
    img = preprocessing.image.load_img("/Users/admin/Documents/6565.jpeg", color_mode='rgb')
    print(np.asarray(img).shape)
    img = preprocessing.image.img_to_array(img, dtype=np.uint8)
    img = resize(img, (640, 640, 3))
    # cv2.imshow expects BGR, so swap the channels for display
    r, g, b = cv2.split(img)
    img_new = cv2.merge((b, g, r))
    cv2.imshow('img', img_new)
    cv2.waitKey()
    # Add the batch dimension and cast to float32 before feeding the network
    img = tf.reshape(img, shape=(1, 640, 640, 3))
    img = tf.cast(img, dtype=tf.float32)
    darknet = DarkNet()
    x1, x2, x3 = darknet.build_darknet(img, "darknet")
    print(x1)
    print(x2)
    print(x3)

The run result:

(500, 600, 3)
tf.Tensor(
[[[[ 0.01300546  0.04580329  0.02618899 ...  0.0024302   0.00778028
     0.02046073]
   [ 0.01523247  0.0433247   0.03101609 ...  0.00859393  0.01517401
     0.04126479]
   [ 0.01569667  0.04655103  0.02909921 ...  0.01153814  0.0103563
     0.04232045]
   ...
   [ 0.01461979  0.05112193  0.03184253 ...  0.01749601  0.01063334
     0.04914174]
   [ 0.01122806  0.04039535  0.03295705 ...  0.01560631  0.00924575
     0.047635  ]
   [ 0.01934911  0.01768233  0.02554618 ...  0.02241347  0.01196023
     0.04969869]]

  [[ 0.01979208  0.04440274  0.03590046 ... -0.00037868  0.02252091
     0.01689843]
   [ 0.0281611   0.04216104  0.03869232 ...  0.00249345  0.0317611
     0.05443078]
   [ 0.02877221  0.0423179   0.04370406 ...  0.01127182  0.02932161
     0.05519786]
   ...
   [ 0.02331033  0.04598416  0.04644613 ...  0.0131951   0.03045727
     0.06477398]
   [ 0.0180282   0.03877962  0.04242241 ...  0.01323145  0.02318294
     0.06416938]
   [ 0.03270163  0.02084929  0.03383677 ...  0.02968623  0.01979706
     0.07166156]]

  [[ 0.02432447  0.050202    0.03617659 ... -0.00233297  0.03178607
     0.01721654]
   [ 0.02453477  0.04319714  0.04343178 ...  0.00594528  0.03310958
     0.06119607]
   [ 0.02629931  0.05030533  0.04756507 ...  0.01565211  0.02879084
     0.06568647]
   ...
   [ 0.02312282  0.04931618  0.05300389 ...  0.02371687  0.02887772
     0.07611271]
   [ 0.01672611  0.04033546  0.05156515 ...  0.02270166  0.02590587
     0.07241585]
   [ 0.0344753   0.01985516  0.0468677  ...  0.03887149  0.02358301
     0.07878023]]

  ...

  [[ 0.01810458  0.03465767  0.0359303  ... -0.00219411  0.02293537
     0.01476098]
   [ 0.0134875   0.02783474  0.0325518  ...  0.00836905  0.01697291
     0.05196021]
   [ 0.01130611  0.03472897  0.04398032 ...  0.01915504  0.01711655
     0.06013133]
   ...
   [ 0.01296154  0.03170425  0.06206837 ...  0.03199976  0.02187147
     0.08614711]
   [ 0.01051291  0.026424    0.05665638 ...  0.02734326  0.01725644
     0.08029547]
   [ 0.02285716  0.01315682  0.05187852 ...  0.04164089  0.02339076
     0.07870194]]

  [[ 0.01582998  0.03424459  0.03260812 ... -0.00325415  0.02369688
     0.01293045]
   [ 0.01067719  0.0237003   0.03294854 ...  0.00414321  0.01935142
     0.04889458]
   [ 0.00755305  0.02681288  0.04651605 ...  0.01095157  0.02234309
     0.05407221]
   ...
   [ 0.00872021  0.02522171  0.06343129 ...  0.02143701  0.02625748
     0.07797993]
   [ 0.00934438  0.02784847  0.05337022 ...  0.01815011  0.0163433
     0.07494247]
   [ 0.02527687  0.01019516  0.0487967  ...  0.03363573  0.02205626
     0.07579175]]

  [[ 0.01823313  0.01269497  0.02201747 ...  0.00057489  0.01901229
     0.01401678]
   [ 0.01871967  0.01256928  0.02453714 ...  0.00147328  0.02255515
     0.04461277]
   [ 0.01581922  0.01575183  0.03556562 ...  0.00610318  0.02761898
     0.04809618]
   ...
   [ 0.02686368  0.01529721  0.04575747 ...  0.01040503  0.03783914
     0.06418921]
   [ 0.02419975  0.01501905  0.04164971 ...  0.01177746  0.03296132
     0.05938878]
   [ 0.02779927  0.01238328  0.04344043 ...  0.02279663  0.02802737
     0.05404975]]]], shape=(1, 80, 80, 256), dtype=float32)
tf.Tensor(
[[[[ 5.57744503e-02  4.37109694e-02  1.97642922e-01 ...  1.28797188e-01
     3.13888304e-02  1.42925873e-01]
   [ 4.39740494e-02  2.06579026e-02  2.41064608e-01 ...  2.13760287e-01
     7.33558387e-02  2.18340456e-01]
   [ 2.09208392e-02  3.09265479e-02  2.46887892e-01 ...  2.80931830e-01
     8.50827098e-02  2.63412178e-01]
   ...
   [ 2.72087976e-02  5.25276959e-02  2.19686821e-01 ...  3.05240124e-01
     8.77861679e-02  2.72402823e-01]
   [ 2.65834890e-02  4.31933478e-02  1.73503757e-01 ...  2.72548974e-01
     8.33168477e-02  2.47264355e-01]
   [-1.60129666e-02  9.67956055e-03  1.04837202e-01 ...  2.23080724e-01
     4.11004983e-02  1.79972097e-01]]

  [[ 9.13270712e-02  4.25641313e-02  2.92546570e-01 ...  1.85710311e-01
     5.48139662e-02  2.48421669e-01]
   [ 6.77959770e-02  2.06972919e-02  4.27425057e-01 ...  3.10112268e-01
     1.27032176e-01  3.22345883e-01]
   [ 2.10890956e-02  2.81567629e-02  4.73780841e-01 ...  3.85221779e-01
     1.43670052e-01  3.64688814e-01]
   ...
   [ 5.77030051e-03  3.50399017e-02  4.49282169e-01 ...  3.97970021e-01
     1.31809086e-01  3.32577288e-01]
   [ 1.65432375e-02  1.57514010e-02  4.09302562e-01 ...  3.83579433e-01
     1.19915843e-01  3.17814291e-01]
   [-6.36246568e-03  5.90018928e-03  2.24725455e-01 ...  3.30318391e-01
     6.04600236e-02  1.92407310e-01]]

  [[ 9.07450020e-02  5.02355471e-02  3.22316527e-01 ...  2.17806578e-01
     7.86794573e-02  2.86677599e-01]
   [ 5.24090864e-02  2.40034275e-02  4.78858292e-01 ...  3.34466815e-01
     1.59046844e-01  3.56041044e-01]
   [-7.65481964e-04  5.54743968e-02  5.47300458e-01 ...  4.32219326e-01
     1.80383742e-01  4.10912067e-01]
   ...
   [-1.89762097e-02  6.06380999e-02  4.94807541e-01 ...  4.25486624e-01
     1.76484093e-01  3.82894129e-01]
   [-1.15267523e-02  4.02785651e-02  4.51946974e-01 ...  4.17101383e-01
     1.62846759e-01  3.76177728e-01]
   [-2.17601825e-02  1.85616724e-02  2.38932043e-01 ...  3.90599698e-01
     5.71805388e-02  2.44674429e-01]]

  ...

  [[ 6.35816827e-02  3.37417796e-02  2.64523119e-01 ...  1.85912058e-01
     7.50792101e-02  2.26906434e-01]
   [ 2.45644636e-02  2.49361265e-02  4.09977913e-01 ...  2.56267220e-01
     1.35049358e-01  2.58147597e-01]
   [-1.11659914e-02  5.32234684e-02  4.76467907e-01 ...  3.11896145e-01
     1.58500791e-01  2.89531261e-01]
   ...
   [-2.83652470e-02  6.99636936e-02  5.44180274e-01 ...  3.79703373e-01
     2.03064114e-01  3.48033041e-01]
   [-3.48832272e-02  5.73238991e-02  4.96190369e-01 ...  3.72090876e-01
     1.81635693e-01  3.58659029e-01]
   [-3.89259271e-02  2.73454078e-02  2.71928340e-01 ...  3.39304954e-01
     8.07696208e-02  2.54779845e-01]]

  [[ 5.60188815e-02  3.48543599e-02  2.20965892e-01 ...  1.62926540e-01
     7.29940161e-02  2.05437750e-01]
   [ 2.22457871e-02  4.96294089e-02  3.31111431e-01 ...  2.18296364e-01
     1.22494869e-01  2.22173691e-01]
   [-3.72293964e-03  8.88322666e-02  3.88356268e-01 ...  2.59815037e-01
     1.47945717e-01  2.44451314e-01]
   ...
   [-2.03779750e-02  1.10053554e-01  4.32410121e-01 ...  3.11534166e-01
     1.87783048e-01  3.03887814e-01]
   [-3.06747276e-02  9.53165665e-02  3.81398976e-01 ...  2.97086895e-01
     1.66954249e-01  2.94213474e-01]
   [-3.05685177e-02  3.21246460e-02  2.30680123e-01 ...  2.59103298e-01
     9.64258835e-02  2.30618730e-01]]

  [[ 4.33812030e-02  3.45876142e-02  1.36073053e-01 ...  7.24842548e-02
     5.62026761e-02  1.23662286e-01]
   [ 2.54868492e-02  5.30315749e-02  2.28306949e-01 ...  9.75592732e-02
     9.43436623e-02  1.28626943e-01]
   [ 1.92903653e-02  7.24871531e-02  2.70146757e-01 ...  1.20144218e-01
     1.10490926e-01  1.47676066e-01]
   ...
   [ 7.02063087e-03  8.36579055e-02  3.48761082e-01 ...  1.58494785e-01
     1.30590826e-01  1.79404482e-01]
   [-1.82694057e-05  5.76314554e-02  3.16185445e-01 ...  1.54782787e-01
     1.05770782e-01  1.70392156e-01]
   [-4.06836253e-03  1.12849586e-02  2.12644458e-01 ...  1.02986082e-01
     6.46547675e-02  9.39983502e-02]]]], shape=(1, 40, 40, 512), dtype=float32)
tf.Tensor(
[[[[ 0.12318148  0.03376219  0.01146538 ...  0.09772492  0.24040186
     0.14421844]
   [ 0.1259669   0.05908597  0.09456294 ...  0.2132321   0.3147785
     0.12912317]
   [ 0.09492658  0.10398786  0.15491882 ...  0.20309469  0.33134025
     0.15457296]
   ...
   [ 0.06671321  0.13512364  0.15603073 ...  0.20307465  0.3218756
     0.1176842 ]
   [ 0.03612547  0.13628691  0.15187964 ...  0.2019774   0.28994283
     0.10808527]
   [-0.0160175   0.14134313  0.18413983 ...  0.18274452  0.13057366
     0.09312407]]

  [[ 0.18831736  0.09769126  0.07227302 ...  0.18843201  0.2376497
     0.23451753]
   [ 0.23235133  0.20166397  0.18566903 ...  0.30471733  0.42872095
     0.23297027]
   [ 0.16126493  0.26441634  0.23911062 ...  0.29786062  0.48642734
     0.26872134]
   ...
   [ 0.11938853  0.28962046  0.258896   ...  0.29687464  0.51437193
     0.19304752]
   [ 0.06673403  0.28723678  0.26476678 ...  0.3100859   0.48218828
     0.1223008 ]
   [-0.02131327  0.25196818  0.40540588 ...  0.3097421   0.2831375
     0.03770523]]

  [[ 0.308563    0.09754747  0.10471658 ...  0.24433848  0.26193607
     0.3069915 ]
   [ 0.325613    0.1744133   0.2388102  ...  0.4038601   0.5247658
     0.2331565 ]
   [ 0.19707479  0.29004464  0.29805616 ...  0.41098198  0.5953533
     0.2674497 ]
   ...
   [ 0.11919326  0.33183312  0.32917818 ...  0.42651626  0.6684428
     0.19919254]
   [ 0.03990063  0.33553725  0.35605377 ...  0.46647537  0.6166377
     0.10448103]
   [-0.07907544  0.32569277  0.5182643  ...  0.41053414  0.36418986
     0.04715421]]

  ...

  [[ 0.38908362  0.11467192  0.10097787 ...  0.32609707  0.04908396
     0.33989137]
   [ 0.36390406  0.08983842  0.25493437 ...  0.45007908  0.46792823
     0.17120835]
   [ 0.27981567  0.22168614  0.31461418 ...  0.5019402   0.5556003
     0.17275102]
   ...
   [ 0.31684056  0.31729382  0.46586263 ...  0.6414348   0.5536462
     0.17497462]
   [ 0.25943795  0.3790891   0.6061166  ...  0.6894225   0.45387262
     0.10104082]
   [ 0.06759107  0.3362249   0.74389505 ...  0.5601219   0.5609029
    -0.09805419]]

  [[ 0.28793362  0.08725777  0.10377577 ...  0.19679242  0.05906167
     0.24279368]
   [ 0.26815853  0.07847413  0.20731884 ...  0.25489905  0.31678712
     0.15527718]
   [ 0.22508654  0.15286478  0.24512848 ...  0.31735316  0.37194827
     0.15467714]
   ...
   [ 0.2053708   0.259059    0.4087341  ...  0.36274028  0.37855875
     0.19536555]
   [ 0.17315325  0.30218142  0.49641716 ...  0.38937995  0.32239732
     0.12586012]
   [ 0.0407083   0.21762022  0.51479626 ...  0.34929705  0.36270672
    -0.04694114]]

  [[ 0.15374078  0.06020061  0.08242871 ...  0.18834715 -0.0276424
     0.1503744 ]
   [ 0.16523457  0.04991119  0.24927962 ...  0.13250573  0.1395475
     0.07700527]
   [ 0.13136236  0.08214448  0.25838062 ...  0.17917573  0.15435402
     0.05991633]
   ...
   [ 0.17718612  0.12624508  0.4046963  ...  0.17965443  0.16851594
     0.05702776]
   [ 0.16144499  0.13742751  0.47926506 ...  0.15510318  0.16570172
     0.02195632]
   [ 0.10069057  0.07071531  0.42017645 ...  0.1173479   0.2154476
    -0.0434314 ]]]], shape=(1, 20, 20, 1024), dtype=float32)

From this output we can see that x3 is a 20*20 feature map with 1024 channels. After a few more convolutional layers of feature extraction, let us look at how the first detection scale is built.
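As a quick sanity check on that shape: Darknet-53 downsamples the input five times with stride-2 convolutions, for an overall stride of 32, which is why a 640*640 input produces the 20*20 map above. A minimal sketch of the arithmetic:

```python
# Five stride-2 stages give an overall stride of 2**5 = 32,
# so a 640x640 input ends up as a 20x20 feature map.
input_size = 640
stride = 2 ** 5
feature_size = input_size // stride
print(feature_size)  # 20
```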

if isinstance(x3, tuple):
    x, x_skip = x3[0], x3[1]

    # concat with skip connection
    x = darknet._darknet_conv(x, 512, 1)
    x = layers.UpSampling2D(2)(x)
    x = layers.Concatenate()([x, x_skip])
else:
    x = x3
# continue extracting features
x = darknet._darknet_conv(x, 512, 1)
x = darknet._darknet_conv(x, 512 * 2, 3)
x = darknet._darknet_conv(x, 512, 1)
x = darknet._darknet_conv(x, 512 * 2, 3)
x = darknet._darknet_conv(x, 512, 1)
# first connection point (reused by the next scale)
concat_output = x

x = darknet._darknet_conv(x, 512 * 2, 3)
# 9 anchor clusters, 3 scales
anchor_masks = np.array([[6, 7, 8], [3, 4, 5], [0, 1, 2]])
# 3 anchors per scale
num_anchors = len(anchor_masks[0])
# [batch, h, w, num_anchors * (num_class + 5)]
# No batch norm or activation here; 91 is the number of classes, and the
# 1*1 conv brings the channel count to 3 * (91 + 5) = 288
x = darknet._darknet_conv(x, num_anchors * (91 + 5), 1, batch_norm=False)
# [batch, h, w, num_anchors, (num_class + 5)]
# For each Bounding box: coordinate offsets, width/height, confidence,
# and 91 class scores
x = layers.Lambda(lambda x: tf.reshape(x, (-1, tf.shape(x)[1], tf.shape(x)[2],
                                                    num_anchors, 91 + 5)))(x)
print('x-feature', x)

Output

x-feature tf.Tensor(
[[[[[ 1.43855521e-02  3.23734321e-02 -3.67170293e-03 ...
     -5.93810296e-03 -2.72588581e-02 -3.25208111e-03]
    [-3.53677664e-03  6.76192809e-03 -4.44413349e-03 ...
      2.82656588e-03 -7.69695640e-03  3.29773836e-02]
    [ 8.74534249e-04  3.14397179e-02  2.98833847e-02 ...
     -6.54635532e-03 -1.05559025e-02  8.90140049e-03]]

   [[ 4.17503864e-02  4.82537672e-02  6.04504673e-03 ...
      2.36063190e-02 -5.35690263e-02 -1.75136477e-02]
    [-8.13679397e-02 -1.36799887e-02 -1.80149078e-02 ...
     -5.34049422e-03  1.09536406e-02  8.79051313e-02]
    [-1.21820532e-02  3.43158655e-02  4.46718559e-02 ...
      4.26378567e-03  2.83443183e-02 -1.76916551e-02]]

   [[ 4.78246845e-02  7.02402964e-02 -2.98725674e-03 ...
      3.17356437e-02 -5.53366765e-02 -2.45371554e-02]
    [-1.09746695e-01 -1.01776635e-02 -1.15971267e-02 ...
      7.45519437e-03  1.98712982e-02  9.70957130e-02]
    [-1.79717652e-02  4.98642251e-02  6.01803586e-02 ...
      8.30962509e-03  3.16046104e-02 -2.41924915e-02]]

   ...

   [[ 7.06956089e-02  5.01229055e-02 -1.92262903e-02 ...
      2.69248001e-02 -7.39700571e-02 -3.62998098e-02]
    [-1.18457004e-01 -3.30782309e-03 -3.68624963e-02 ...
      6.42429944e-03  3.01392116e-02  1.19885460e-01]
    [-1.75386984e-02  5.80466166e-02  5.58605045e-02 ...
      2.99791712e-02  5.00387549e-02 -3.73844020e-02]]

   [[ 8.40673596e-02  3.32510434e-02 -4.41334471e-02 ...
      3.93143483e-03 -5.68560772e-02 -4.21492085e-02]
    [-9.49152410e-02  5.57698309e-04 -3.83891016e-02 ...
      7.42006395e-03  3.16541083e-02  1.02742508e-01]
    [-3.06121670e-02  4.80219983e-02  4.90897596e-02 ...
      4.73524630e-02  4.97985967e-02 -2.46338211e-02]]

   [[ 7.96302259e-02  1.91449262e-02 -3.81032005e-02 ...
     -1.58824027e-04 -4.31968197e-02 -2.81720422e-02]
    [-1.01859346e-01  1.41646825e-02 -1.52143966e-02 ...
     -5.00067556e-03  2.77400874e-02  7.79438317e-02]
    [-1.60612054e-02  9.17838793e-03  2.12875605e-02 ...
      3.83229628e-02  2.12987885e-02 -1.58014931e-02]]]


  [[[ 2.74948329e-02  5.05619273e-02 -2.29799170e-02 ...
     -7.93745508e-04 -5.51016741e-02  4.98068705e-03]
    [-4.01594229e-02  2.39265691e-02 -2.29697190e-02 ...
      2.10843701e-02  3.60490847e-03  7.17011243e-02]
    [-3.13397646e-02  3.12914401e-02  5.55699021e-02 ...
     -1.86149422e-02 -3.71585526e-02 -1.11365691e-02]]

   [[ 6.28519207e-02  6.23441786e-02 -1.86228883e-02 ...
      3.08187008e-02 -6.98437244e-02 -5.62309194e-03]
    [-1.26209423e-01  1.00298151e-02 -4.34425063e-02 ...
      3.09237503e-02  3.99011839e-03  1.37707442e-01]
    [-5.87570965e-02  6.70248717e-02  9.58192796e-02 ...
     -1.66721328e-03 -8.48327018e-03 -3.57441008e-02]]

   [[ 7.48019442e-02  7.09551796e-02 -4.55776900e-02 ...
      2.30621602e-02 -8.92516002e-02 -1.51880067e-02]
    [-1.76495656e-01  7.75676500e-03 -3.12610716e-02 ...
      3.46352980e-02 -1.50344521e-03  1.68000251e-01]
    [-6.90756738e-02  6.00972995e-02  9.68224928e-02 ...
      8.31374712e-03 -2.22650059e-02 -3.50602530e-02]]

   ...

   [[ 9.00407583e-02  6.28758818e-02 -7.87543431e-02 ...
      2.28788834e-02 -1.14216708e-01 -3.77769247e-02]
    [-1.90228492e-01  1.82287768e-05 -2.58688331e-02 ...
      4.48843502e-02  1.38927093e-02  2.05337286e-01]
    [-5.78798950e-02  1.05381131e-01  1.00754865e-01 ...
      3.34641226e-02  2.75914967e-02 -5.70844561e-02]]

   [[ 1.05826035e-01  4.62486446e-02 -8.39314014e-02 ...
      9.15204361e-03 -1.13347352e-01 -5.53399697e-02]
    [-1.50426179e-01 -3.87058128e-03 -3.00489590e-02 ...
      2.81358380e-02  1.88718829e-02  1.78275421e-01]
    [-5.05372919e-02  1.10970370e-01  7.74718523e-02 ...
      4.80326340e-02  2.88719907e-02 -2.23084521e-02]]

   [[ 9.67991948e-02  2.99865901e-02 -5.90101480e-02 ...
      2.17598621e-02 -7.23858923e-02 -5.25329933e-02]
    [-1.40194416e-01  3.17972936e-02 -1.44932307e-02 ...
      7.22579099e-03  3.50863002e-02  1.28113016e-01]
    [-3.52293998e-02  2.98271589e-02  4.73817810e-02 ...
      4.75752540e-02  3.86793353e-03  1.14865787e-03]]]


  [[[ 4.17320058e-02  7.47939944e-02 -3.67180519e-02 ...
      8.78189108e-04 -7.91597962e-02 -7.43087381e-03]
    [-3.72411422e-02  2.48172469e-02 -8.54682736e-03 ...
      3.10143922e-02  8.72607157e-03  9.71082598e-02]
    [-4.64725122e-02  3.81718948e-02  7.09193796e-02 ...
     -2.27786489e-02 -5.76945543e-02 -2.00277753e-02]]

   [[ 8.94888863e-02  9.55630019e-02 -2.68667266e-02 ...
      1.83212534e-02 -8.09980258e-02 -2.87768878e-02]
    [-1.29027143e-01  2.92052981e-03 -2.65594479e-02 ...
      5.60834520e-02 -7.63696432e-03  1.87813476e-01]
    [-8.47262442e-02  1.02939576e-01  9.95508805e-02 ...
     -5.55549329e-03 -4.15047444e-03 -4.96640876e-02]]

   [[ 1.16421722e-01  8.58199894e-02 -5.58324568e-02 ...
      7.33696949e-03 -1.09642014e-01 -4.48607728e-02]
    [-1.84635162e-01  1.18100038e-02 -8.70808028e-04 ...
      6.05232753e-02 -8.16491246e-03  2.26934075e-01]
    [-9.52242613e-02  9.61375087e-02  1.16436489e-01 ...
      1.51424110e-02 -6.67663850e-03 -4.60440591e-02]]

   ...

   [[ 1.53658032e-01  1.08256996e-01 -1.26264885e-01 ...
      1.84526388e-02 -1.51171982e-01 -7.67052174e-02]
    [-2.43474692e-01  1.49110164e-02 -1.30938403e-02 ...
      8.31634849e-02 -3.53191309e-02  2.80060172e-01]
    [-8.59394073e-02  1.34953350e-01  1.24888279e-01 ...
      5.64943552e-02  4.94653843e-02 -6.03685342e-02]]

   [[ 1.44291222e-01  8.66962001e-02 -1.22786529e-01 ...
      7.17543438e-03 -1.52220428e-01 -9.69632566e-02]
    [-2.01931000e-01 -1.39550492e-03 -2.16482691e-02 ...
      3.62620205e-02 -2.45773830e-02  2.31392995e-01]
    [-7.56062120e-02  1.36760980e-01  1.06802300e-01 ...
      6.82064295e-02  4.44108844e-02 -3.16042416e-02]]

   [[ 1.33012280e-01  4.28547040e-02 -9.10417140e-02 ...
      2.12642662e-02 -1.08710676e-01 -6.61423728e-02]
    [-1.81244701e-01  4.85137552e-02 -2.70637348e-02 ...
     -1.13902893e-02  1.35496520e-02  1.63882941e-01]
    [-5.21412939e-02  4.90978062e-02  6.27045110e-02 ...
      8.06145146e-02  1.60848722e-02  8.45119730e-03]]]


  ...


  [[[ 1.08628646e-02  1.11254916e-01 -8.07504803e-02 ...
      1.21952817e-02 -1.03620023e-01 -2.20873058e-02]
    [-6.67699724e-02  5.96636832e-02  5.51500320e-02 ...
      7.00847656e-02 -1.07938163e-02  8.37760791e-02]
    [-5.77524193e-02  5.51094189e-02  6.49683774e-02 ...
     -4.26605567e-02 -8.21135566e-03 -1.28405597e-02]]

   [[ 5.81651032e-02  1.49165615e-01 -1.81277283e-02 ...
      3.06064263e-02 -1.37036607e-01 -7.50004202e-02]
    [-1.37735859e-01  8.04366693e-02  7.41885602e-02 ...
      6.55304939e-02 -2.43935101e-02  1.74660489e-01]
    [-5.84301054e-02  7.85881653e-02  2.31113024e-02 ...
     -1.54732708e-02 -5.13509251e-02  9.48662870e-04]]

   [[ 1.06464945e-01  1.24650642e-01 -3.61565799e-02 ...
      1.52456388e-03 -1.68448240e-01 -1.38836905e-01]
    [-2.12003008e-01  5.89370839e-02  5.97102344e-02 ...
      5.59944659e-02 -2.49991696e-02  2.08926409e-01]
    [-5.65537736e-02  1.12186335e-01  3.25093418e-02 ...
      3.28003578e-02 -8.32641423e-02  1.00097563e-02]]

   ...

   [[ 1.51114792e-01  1.58362105e-01 -6.47453591e-02 ...
      5.14809787e-03 -2.83097416e-01 -2.57758439e-01]
    [-3.13894004e-01  7.31603056e-02  6.23276010e-02 ...
      7.82078877e-02 -4.45402861e-02  2.94473231e-01]
    [-8.32889974e-02  1.22369662e-01  8.58929753e-02 ...
      7.54653439e-02 -8.54809880e-02 -1.32431313e-02]]

   [[ 1.33512348e-01  1.26630753e-01 -4.58583161e-02 ...
     -1.28193274e-02 -2.54865557e-01 -2.46457934e-01]
    [-2.78497875e-01  8.39896873e-02 -2.95564532e-04 ...
     -8.82694125e-03 -6.67940453e-02  2.41247982e-01]
    [-6.61531016e-02  1.37382865e-01  7.88134113e-02 ...
      1.80205762e-01 -8.17115009e-02  2.97585502e-02]]

   [[ 1.29782706e-01  2.50800662e-02 -6.52197376e-02 ...
     -2.07002461e-02 -1.56805068e-01 -1.45573348e-01]
    [-2.44694024e-01  1.00220524e-01 -1.42512023e-01 ...
     -2.84452643e-02 -1.67766735e-02  1.54159695e-01]
    [-3.97469252e-02  3.13467085e-02  3.77028584e-02 ...
      1.52708039e-01  1.33550316e-02  3.96765023e-03]]]


  [[[ 3.32733244e-03  8.66670758e-02 -6.07146546e-02 ...
      1.44701917e-02 -7.41072074e-02 -1.12504112e-02]
    [-5.52796051e-02  6.39907420e-02  5.91539070e-02 ...
      4.56419885e-02 -1.43662700e-02  6.01993538e-02]
    [-4.13223058e-02  4.23230007e-02  5.12221456e-02 ...
     -2.44814064e-02  7.68753141e-03 -1.34938741e-02]]

   [[ 3.15343663e-02  1.12439543e-01 -1.23838335e-02 ...
      3.96174788e-02 -8.55981112e-02 -7.04407394e-02]
    [-1.02561615e-01  7.36098662e-02  6.62841946e-02 ...
      5.17353155e-02 -3.26524451e-02  1.02890044e-01]
    [-3.34513411e-02  4.52073961e-02  1.34149753e-02 ...
     -3.91144399e-03 -4.60506305e-02  1.18983788e-02]]

   [[ 4.87772003e-02  9.48870182e-02 -2.31062546e-02 ...
      1.21587627e-02 -1.22057237e-01 -1.12674981e-01]
    [-1.46121517e-01  4.62991670e-02  6.22609779e-02 ...
      5.49749732e-02 -2.72053480e-02  1.34005874e-01]
    [-3.90926264e-02  6.88402504e-02  2.43551135e-02 ...
      2.40953639e-02 -7.21580312e-02  6.27383497e-03]]

   ...

   [[ 6.31585419e-02  1.37103528e-01 -5.87803200e-02 ...
      4.47873026e-03 -2.04616517e-01 -1.89852357e-01]
    [-2.06350535e-01  7.93156251e-02  4.91615236e-02 ...
      2.46029869e-02 -3.90526429e-02  1.89593405e-01]
    [-4.61713150e-02  8.22851062e-02  6.44003153e-02 ...
      8.36576372e-02 -7.58771226e-02 -6.70336187e-03]]

   [[ 4.83803637e-02  1.05444193e-01 -4.93297055e-02 ...
     -2.84428932e-02 -1.93648577e-01 -1.60674170e-01]
    [-1.77676767e-01  7.94684216e-02  5.94073907e-03 ...
     -1.80556979e-02 -3.90497334e-02  1.71817034e-01]
    [-4.95870933e-02  8.17895085e-02  5.93550652e-02 ...
      1.39275029e-01 -7.31767714e-02  2.99034156e-02]]

   [[ 7.52071366e-02  4.28408384e-07 -4.29156683e-02 ...
     -2.15305649e-02 -1.18705362e-01 -1.08373582e-01]
    [-1.53610006e-01  8.31815824e-02 -1.10872447e-01 ...
     -2.87074856e-02 -4.76869289e-03  1.06467336e-01]
    [-1.55894952e-02  1.17892586e-02  2.64222510e-02 ...
      1.16865277e-01  1.46103539e-02  1.16869062e-03]]]


  [[[ 1.61800720e-03  2.64321100e-02 -3.98671329e-02 ...
      9.84302349e-03 -5.75484447e-02 -6.95041055e-03]
    [-1.87479239e-02  4.90886793e-02  5.64492717e-02 ...
      2.24619880e-02 -1.82566028e-02  2.19828263e-02]
    [-1.61643643e-02  1.01023884e-02  8.82678013e-03 ...
     -3.97317577e-04  6.44671265e-03  6.35148119e-03]]

   [[-5.73368184e-03  3.67868356e-02 -1.72267333e-02 ...
      1.94614641e-02 -6.88518360e-02 -6.84302747e-02]
    [-4.98339571e-02  5.38594611e-02  6.75405487e-02 ...
      6.59726374e-03 -2.86071394e-02  2.94821672e-02]
    [-1.70322154e-02  3.23171122e-03 -9.67009738e-03 ...
      5.85122686e-03 -2.35910937e-02  2.97874026e-02]]

   [[-1.16223255e-02  3.66859548e-02 -7.65427202e-03 ...
      1.64846610e-02 -9.62952226e-02 -9.62916389e-02]
    [-7.86703229e-02  2.86225770e-02  6.18349276e-02 ...
      1.29025616e-03 -2.02125106e-02  3.72486040e-02]
    [-1.20273829e-02  3.47230658e-02 -4.51089814e-03 ...
      2.45419908e-02 -2.62224786e-02  2.21864134e-02]]

   ...

   [[-3.19192931e-02  7.18351603e-02 -4.28153500e-02 ...
      7.03381747e-03 -1.36818856e-01 -1.56896755e-01]
    [-1.05386548e-01  4.15120684e-02  6.81234822e-02 ...
     -3.12198270e-02 -3.42080072e-02  6.01806380e-02]
    [-2.06874460e-02  4.44727913e-02  1.31856017e-02 ...
      5.60805239e-02 -1.72917321e-02 -1.07942522e-03]]

   [[-2.97704488e-02  7.49808848e-02 -3.62125337e-02 ...
     -2.06511952e-02 -1.19673520e-01 -1.37231916e-01]
    [-7.55185559e-02  4.33621854e-02  4.26110551e-02 ...
     -4.26517539e-02 -2.30262168e-02  8.32250267e-02]
    [-6.60321116e-03  3.54940891e-02  2.36980543e-02 ...
      8.46842751e-02 -2.45643649e-02  1.37924962e-02]]

   [[ 5.91525901e-03  2.44261324e-03 -3.34736444e-02 ...
     -2.78865341e-02 -6.24412596e-02 -1.12322122e-01]
    [-8.07825625e-02  3.91773060e-02 -4.91225272e-02 ...
     -3.75587940e-02 -2.48860158e-02  4.39865328e-02]
    [ 7.03450711e-03 -4.43650782e-03  1.44479815e-02 ...
      6.97280839e-02  3.27673666e-02 -1.41870379e-02]]]]], shape=(1, 20, 20, 3, 96), dtype=float32)

From this output we can see that the tensor represents a 20*20 feature map with 3 anchor sizes per cell; each anchor carries 91 class scores, 4 coordinate values, and one confidence, for 91 + 5 = 96 channels.
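The reshape done by the Lambda layer above, and the split of the last axis that follows, can be sketched with NumPy on a dummy tensor of the same shape (the zeros stand in for real network output):

```python
import numpy as np

# Dummy head output for one 20x20 map: 3 anchors x (91 classes + 5) = 288 channels.
num_anchors, num_classes = 3, 91
flat = np.zeros((1, 20, 20, num_anchors * (num_classes + 5)), np.float32)

# Reshape the 288 channels into (anchors, 96), as the Lambda layer does.
per_anchor = flat.reshape(1, 20, 20, num_anchors, num_classes + 5)

# The last axis then splits as 2 (xy) + 2 (wh) + 1 (objectness) + 91 (classes).
xy, wh, obj, cls = np.split(per_anchor, [2, 4, 5], axis=-1)
print(per_anchor.shape, xy.shape[-1], wh.shape[-1], obj.shape[-1], cls.shape[-1])
# (1, 20, 20, 3, 96) 2 2 1 91
```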

# anchor box shapes (w, h), normalized by the input size
anchors = np.array([[17, 20], [43, 52], [66, 127], [132, 69], [116, 243], [205, 149],
                    [233, 363], [410, 216], [496, 440]], np.float32) / image_shape[0]
first_out_bbox, first_out_objectness, first_out_class_probs, first_out_pred_box = layers.Lambda(
    lambda x: yolo_boxes(x, anchors[anchor_masks[0]], 91),
    name='yolo_boxes_first_out')(x)

We run this computation on each output tensor to obtain, for every anchor, the box coordinates, confidence, class probabilities, and raw coordinate offsets. This differs from the Anchor mechanism in Faster RCNN: Faster RCNN anchors at the center of the original-image region that each feature-map pixel maps to, and generates 9 differently shaped boxes there to obtain Bounding boxes. YOLO instead divides the feature map into grid cells and predicts whether each cell contains a target. YOLO uses three grid resolutions, and each cell has its own center point.
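The pairing of anchor masks with scales can be seen directly: `anchor_masks[0]` selects the three largest clusters for the coarsest 20*20 map, whose cells cover the largest image regions. A small demo (the 640 input size is taken from the model above):

```python
import numpy as np

# The 9 anchors are sorted from small to large; each mask picks 3 of them.
image_width = 640  # assumed input size, matching the model above
anchors = np.array([[17, 20], [43, 52], [66, 127], [132, 69], [116, 243],
                    [205, 149], [233, 363], [410, 216], [496, 440]],
                   np.float32) / image_width
anchor_masks = np.array([[6, 7, 8], [3, 4, 5], [0, 1, 2]])

# Normalized (w, h) of the three largest anchors, used by the 20x20 scale.
print(anchors[anchor_masks[0]])
```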

def yolo_boxes(pred, anchors, num_classes):
    """Decode the raw head output into final predictions."""
    # pred: (batch_size, grid, grid, anchors, (x, y, w, h, obj, ...classes))
    # Get the feature-map size; larger maps detect smaller targets,
    # smaller maps detect larger targets
    grid_size = tf.shape(pred)[1:3]
    print('grid_size', grid_size)
    # Split the last axis 2:2:1:91 - the first two values are the xy offsets
    # of the box center, the next two its width and height, then one
    # confidence value, and finally the class scores
    box_xy, box_wh, objectness, class_probs = tf.split(pred, (2, 2, 1, num_classes), axis=-1)
    # Squash the xy offsets to 0~1 so the center stays inside its cell
    box_xy = tf.sigmoid(box_xy)
    # Squash the confidence to 0~1
    objectness = tf.sigmoid(objectness)
    # Normalize the 91 class scores into probabilities
    class_probs = layers.Softmax()(class_probs)
    # Keep the raw xywh predictions for the loss
    pred_box = tf.concat((box_xy, box_wh), axis=-1)  # original xywh for loss

    # Build a grid of cell indices, i.e. the top-left coordinates that the
    # offsets are relative to
    grid = tf.meshgrid(tf.range(grid_size[1]), tf.range(grid_size[0]))
    # Stack the grid and add an anchor axis
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
    # Absolute center = cell top-left + offset, then normalize to 0~1
    box_xy = (box_xy + tf.cast(grid, tf.float32)) / tf.cast(grid_size, tf.float32)
    # Recover width and height from the anchor shape
    box_wh = tf.exp(box_wh) * anchors
    print('box_wh', box_wh)
    # Convert center/size into top-left and bottom-right corners
    box_x1y1 = box_xy - box_wh / 2
    box_x2y2 = box_xy + box_wh / 2
    x1, y1 = tf.split(box_x1y1, (1, 1), axis=-1)
    x2, y2 = tf.split(box_x2y2, (1, 1), axis=-1)
    # Clip to a valid range (the coordinates are normalized to 0~1, so the
    # image-size upper bound here mainly acts as a lower-bound clip at 0)
    x1 = tf.minimum(tf.maximum(x1, 0.), image_shape[1])
    y1 = tf.minimum(tf.maximum(y1, 0.), image_shape[0])
    x2 = tf.minimum(tf.maximum(x2, 0.), image_shape[1])
    y2 = tf.minimum(tf.maximum(y2, 0.), image_shape[0])
    # Concatenate into one box tensor for IoU computation
    bbox = tf.concat([x1, y1, x2, y2], axis=-1)

    return bbox, objectness, class_probs, pred_box

Output

grid_size tf.Tensor([20 20], shape=(2,), dtype=int32)
grid_size tf.Tensor([40 40], shape=(2,), dtype=int32)
grid_size tf.Tensor([80 80], shape=(2,), dtype=int32)
box_wh tf.Tensor(
[[[[[6.39551878e-01 3.97852451e-01]
    [7.99882174e-01 3.92906696e-01]
    [6.22583508e-01 1.02322054e+00]]]]]
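The decoding done by `yolo_boxes` can be checked by hand for a single cell. A minimal NumPy sketch (the toy raw values and cell indices are made up for illustration): the raw xy offsets pass through a sigmoid, are added to the cell's grid coordinates, and are divided by the grid size; wh is `exp(raw)` scaled by the anchor shape.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

grid_size = 20
gx, gy = 5, 7                          # cell column and row (toy values)
raw_xy = np.array([0.0, 0.0])          # sigmoid(0) = 0.5 -> cell center
raw_wh = np.array([0.0, 0.0])          # exp(0) = 1 -> anchor shape unchanged
anchor = np.array([233, 363]) / 640.0  # one large anchor, normalized

# Same formulas as in yolo_boxes
box_xy = (sigmoid(raw_xy) + np.array([gx, gy])) / grid_size
box_wh = np.exp(raw_wh) * anchor
print(box_xy, box_wh)  # center (0.275, 0.375), size equal to the anchor
```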

Next we build the second detection scale.

feature_maps = (concat_output, x2)
if isinstance(feature_maps, tuple):
    x, x_skip = feature_maps[0], feature_maps[1]

    # concat with skip connection
    x = darknet._darknet_conv(x, 256, 1)
    # Upsample the coarse branch so its spatial size matches x2
    x = layers.UpSampling2D(2)(x)
    # Concatenate x2 with the upsampled feature map
    x = layers.Concatenate()([x, x_skip])
else:
    x = feature_maps
# continue extracting features
x = darknet._darknet_conv(x, 256, 1)
x = darknet._darknet_conv(x, 256 * 2, 3)
x = darknet._darknet_conv(x, 256, 1)
x = darknet._darknet_conv(x, 256 * 2, 3)
x = darknet._darknet_conv(x, 256, 1)
# second connection point
concat_output = x

x = darknet._darknet_conv(x, 256 * 2, 3)
num_anchors = len(anchor_masks[1])
# [batch, h, w, num_anchors * (num_class + 5)]
# No batch norm or activation here; 91 is the number of classes, and the
# 1*1 conv brings the channel count to 288
x = darknet._darknet_conv(x, num_anchors * (91 + 5), 1, batch_norm=False)
# [batch, h, w, num_anchors, (num_class + 5)]
# For each Bounding box: coordinate offsets, width/height, confidence,
# and 91 class scores
x = layers.Lambda(lambda x: tf.reshape(x, (-1, tf.shape(x)[1], tf.shape(x)[2],
                                                    num_anchors, 91 + 5)))(x)
# Decode every anchor's box coordinates, confidence, class probabilities,
# and raw coordinate offsets
second_out_bbox, second_out_objectness, second_out_class_probs, second_out_pred_box = layers.Lambda(
    lambda x: yolo_boxes(x, anchors[anchor_masks[1]], 91),
    name='yolo_boxes_second_out')(x)
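The upsample-and-concatenate step above is just shape arithmetic: the 20*20 branch is 1*1-convolved to 256 channels, nearest-neighbour upsampled to 40*40, then concatenated with the 40*40*512 skip feature along the channel axis. A NumPy sketch with zero-filled stand-ins for the real features:

```python
import numpy as np

x = np.zeros((1, 20, 20, 256), np.float32)       # branch after the 1x1 conv
# Repeating each row and column is equivalent to UpSampling2D(2)
x_up = x.repeat(2, axis=1).repeat(2, axis=2)
x_skip = np.zeros((1, 40, 40, 512), np.float32)  # x2 from the backbone (assumed shape)
merged = np.concatenate([x_up, x_skip], axis=-1)
print(x_up.shape, merged.shape)  # (1, 40, 40, 256) (1, 40, 40, 768)
```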

Next we build the third detection scale.

feature_maps = (concat_output, x1)
if isinstance(feature_maps, tuple):
    x, x_skip = feature_maps[0], feature_maps[1]

    # concat with skip connection
    x = darknet._darknet_conv(x, 128, 1)
    # Upsample the coarse branch so its spatial size matches x1
    x = layers.UpSampling2D(2)(x)
    # Concatenate x1 with the upsampled feature map
    x = layers.Concatenate()([x, x_skip])
else:
    x = feature_maps
# continue extracting features
x = darknet._darknet_conv(x, 128, 1)
x = darknet._darknet_conv(x, 128 * 2, 3)
x = darknet._darknet_conv(x, 128, 1)
x = darknet._darknet_conv(x, 128 * 2, 3)
x = darknet._darknet_conv(x, 128, 1)
# third connection point
concat_output = x

x = darknet._darknet_conv(x, 128 * 2, 3)
num_anchors = len(anchor_masks[2])
# [batch, h, w, num_anchors * (num_class + 5)]
# No batch norm or activation here; 91 is the number of classes, and the
# 1*1 conv brings the channel count to 288
x = darknet._darknet_conv(x, num_anchors * (91 + 5), 1, batch_norm=False)
# [batch, h, w, num_anchors, (num_class + 5)]
# For each Bounding box: coordinate offsets, width/height, confidence,
# and 91 class scores
x = layers.Lambda(lambda x: tf.reshape(x, (-1, tf.shape(x)[1], tf.shape(x)[2],
                                           num_anchors, 91 + 5)))(x)
# Decode every anchor's box coordinates, confidence, class probabilities,
# and raw coordinate offsets
third_out_bbox, third_out_objectness, third_out_class_probs, third_out_pred_box = layers.Lambda(
    lambda x: yolo_boxes(x, anchors[anchor_masks[2]], 91),
    name='yolo_boxes_third_out')(x)

is_training = True
if is_training:
    model = models.Model(inputs=inputs, outputs=[
        [first_out_bbox, first_out_objectness, first_out_class_probs, first_out_pred_box],
        [second_out_bbox, second_out_objectness, second_out_class_probs, second_out_pred_box],
        [third_out_bbox, third_out_objectness, third_out_class_probs, third_out_pred_box]
    ])
    print(model.summary())

To be able to print the network structure here, we change the earlier call to

x1, x2, x3 = darknet.build_darknet(inputs, "darknet")

Output

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_images (InputLayer)       [(None, 640, 640, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 640, 640, 32) 864         input_images[0][0]               
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 640, 640, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 640, 640, 32) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 641, 641, 32) 0           leaky_re_lu[0][0]                
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 320, 320, 64) 18432       zero_padding2d[0][0]             
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 320, 320, 64) 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 320, 320, 64) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 320, 320, 32) 2048        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 320, 320, 32) 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 320, 320, 32) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 320, 320, 64) 18432       leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 320, 320, 64) 256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 320, 320, 64) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
add (Add)                       (None, 320, 320, 64) 0           leaky_re_lu_1[0][0]              
                                                                 leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 321, 321, 64) 0           add[0][0]                        
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 160, 160, 128 73728       zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 160, 160, 128 512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 160, 160, 128 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 160, 160, 64) 8192        leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 160, 160, 64) 256         conv2d_5[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 160, 160, 64) 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 160, 160, 128 73728       leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 160, 160, 128 512         conv2d_6[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 160, 160, 128 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
add_1 (Add)                     (None, 160, 160, 128 0           leaky_re_lu_4[0][0]              
                                                                 leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 160, 160, 64) 8192        add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 160, 160, 64) 256         conv2d_7[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 160, 160, 64) 0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 160, 160, 128 73728       leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 160, 160, 128 512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 160, 160, 128 0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 160, 160, 128 0           add_1[0][0]                      
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 161, 161, 128 0           add_2[0][0]                      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 80, 80, 256)  294912      zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 80, 80, 256)  1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 80, 80, 256)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 80, 80, 128)  32768       leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 80, 80, 128)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 80, 80, 256)  1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
add_3 (Add)                     (None, 80, 80, 256)  0           leaky_re_lu_9[0][0]              
                                                                 leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 80, 80, 128)  32768       add_3[0][0]                      
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 80, 80, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_12[0][0]             
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 80, 80, 256)  1024        conv2d_13[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
add_4 (Add)                     (None, 80, 80, 256)  0           add_3[0][0]                      
                                                                 leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 80, 80, 128)  32768       add_4[0][0]                      
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 80, 80, 128)  512         conv2d_14[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 80, 80, 256)  1024        conv2d_15[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
add_5 (Add)                     (None, 80, 80, 256)  0           add_4[0][0]                      
                                                                 leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 80, 80, 128)  32768       add_5[0][0]                      
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 80, 80, 128)  512         conv2d_16[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_16[0][0]             
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 80, 80, 256)  1024        conv2d_17[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
add_6 (Add)                     (None, 80, 80, 256)  0           add_5[0][0]                      
                                                                 leaky_re_lu_17[0][0]             
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 80, 80, 128)  32768       add_6[0][0]                      
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 80, 80, 128)  512         conv2d_18[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_18 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 80, 80, 256)  1024        conv2d_19[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_19 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
add_7 (Add)                     (None, 80, 80, 256)  0           add_6[0][0]                      
                                                                 leaky_re_lu_19[0][0]             
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 80, 80, 128)  32768       add_7[0][0]                      
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 80, 80, 128)  512         conv2d_20[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_20 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_20[0][0]     
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_20[0][0]             
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 80, 80, 256)  1024        conv2d_21[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_21 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_21[0][0]     
__________________________________________________________________________________________________
add_8 (Add)                     (None, 80, 80, 256)  0           add_7[0][0]                      
                                                                 leaky_re_lu_21[0][0]             
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 80, 80, 128)  32768       add_8[0][0]                      
__________________________________________________________________________________________________
batch_normalization_22 (BatchNo (None, 80, 80, 128)  512         conv2d_22[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_22[0][0]     
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
batch_normalization_23 (BatchNo (None, 80, 80, 256)  1024        conv2d_23[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_23[0][0]     
__________________________________________________________________________________________________
add_9 (Add)                     (None, 80, 80, 256)  0           add_8[0][0]                      
                                                                 leaky_re_lu_23[0][0]             
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 80, 80, 128)  32768       add_9[0][0]                      
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, 80, 80, 128)  512         conv2d_24[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_24 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_24[0][0]     
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_24[0][0]             
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 80, 80, 256)  1024        conv2d_25[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_25 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
add_10 (Add)                    (None, 80, 80, 256)  0           add_9[0][0]                      
                                                                 leaky_re_lu_25[0][0]             
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 81, 81, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 40, 40, 512)  1179648     zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 40, 40, 512)  2048        conv2d_26[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_26 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_26[0][0]     
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 40, 40, 256)  131072      leaky_re_lu_26[0][0]             
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 40, 40, 256)  1024        conv2d_27[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_27 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_27[0][0]     
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_27[0][0]             
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 40, 40, 512)  2048        conv2d_28[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_28 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
add_11 (Add)                    (None, 40, 40, 512)  0           leaky_re_lu_26[0][0]             
                                                                 leaky_re_lu_28[0][0]             
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 40, 40, 256)  131072      add_11[0][0]                     
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 40, 40, 256)  1024        conv2d_29[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_29 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_29[0][0]     
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_29[0][0]             
__________________________________________________________________________________________________
batch_normalization_30 (BatchNo (None, 40, 40, 512)  2048        conv2d_30[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_30 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_30[0][0]     
__________________________________________________________________________________________________
add_12 (Add)                    (None, 40, 40, 512)  0           add_11[0][0]                     
                                                                 leaky_re_lu_30[0][0]             
__________________________________________________________________________________________________
conv2d_31 (Conv2D)              (None, 40, 40, 256)  131072      add_12[0][0]                     
__________________________________________________________________________________________________
batch_normalization_31 (BatchNo (None, 40, 40, 256)  1024        conv2d_31[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_31 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_31[0][0]     
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_31[0][0]             
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 40, 40, 512)  2048        conv2d_32[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_32[0][0]     
__________________________________________________________________________________________________
add_13 (Add)                    (None, 40, 40, 512)  0           add_12[0][0]                     
                                                                 leaky_re_lu_32[0][0]             
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 40, 40, 256)  131072      add_13[0][0]                     
__________________________________________________________________________________________________
batch_normalization_33 (BatchNo (None, 40, 40, 256)  1024        conv2d_33[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_33 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_33[0][0]     
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_33[0][0]             
__________________________________________________________________________________________________
batch_normalization_34 (BatchNo (None, 40, 40, 512)  2048        conv2d_34[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_34 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_34[0][0]     
__________________________________________________________________________________________________
add_14 (Add)                    (None, 40, 40, 512)  0           add_13[0][0]                     
                                                                 leaky_re_lu_34[0][0]             
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 40, 40, 256)  131072      add_14[0][0]                     
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 40, 40, 256)  1024        conv2d_35[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_35 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_35[0][0]     
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_35[0][0]             
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 40, 40, 512)  2048        conv2d_36[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_36 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_36[0][0]     
__________________________________________________________________________________________________
add_15 (Add)                    (None, 40, 40, 512)  0           add_14[0][0]                     
                                                                 leaky_re_lu_36[0][0]             
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 40, 40, 256)  131072      add_15[0][0]                     
__________________________________________________________________________________________________
batch_normalization_37 (BatchNo (None, 40, 40, 256)  1024        conv2d_37[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_37 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_37[0][0]     
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_37[0][0]             
__________________________________________________________________________________________________
batch_normalization_38 (BatchNo (None, 40, 40, 512)  2048        conv2d_38[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_38 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_38[0][0]     
__________________________________________________________________________________________________
add_16 (Add)                    (None, 40, 40, 512)  0           add_15[0][0]                     
                                                                 leaky_re_lu_38[0][0]             
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 40, 40, 256)  131072      add_16[0][0]                     
__________________________________________________________________________________________________
batch_normalization_39 (BatchNo (None, 40, 40, 256)  1024        conv2d_39[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_39 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_39[0][0]     
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_39[0][0]             
__________________________________________________________________________________________________
batch_normalization_40 (BatchNo (None, 40, 40, 512)  2048        conv2d_40[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_40 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_40[0][0]     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 40, 40, 512)  0           add_16[0][0]                     
                                                                 leaky_re_lu_40[0][0]             
__________________________________________________________________________________________________
conv2d_41 (Conv2D)              (None, 40, 40, 256)  131072      add_17[0][0]                     
__________________________________________________________________________________________________
batch_normalization_41 (BatchNo (None, 40, 40, 256)  1024        conv2d_41[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_41 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_41[0][0]     
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_41[0][0]             
__________________________________________________________________________________________________
batch_normalization_42 (BatchNo (None, 40, 40, 512)  2048        conv2d_42[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_42 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_42[0][0]     
__________________________________________________________________________________________________
add_18 (Add)                    (None, 40, 40, 512)  0           add_17[0][0]                     
                                                                 leaky_re_lu_42[0][0]             
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 41, 41, 512)  0           add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_43 (Conv2D)              (None, 20, 20, 1024) 4718592     zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
batch_normalization_43 (BatchNo (None, 20, 20, 1024) 4096        conv2d_43[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_43 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_43[0][0]     
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 20, 20, 512)  524288      leaky_re_lu_43[0][0]             
__________________________________________________________________________________________________
batch_normalization_44 (BatchNo (None, 20, 20, 512)  2048        conv2d_44[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_44 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_44[0][0]     
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_44[0][0]             
__________________________________________________________________________________________________
batch_normalization_45 (BatchNo (None, 20, 20, 1024) 4096        conv2d_45[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_45 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_45[0][0]     
__________________________________________________________________________________________________
add_19 (Add)                    (None, 20, 20, 1024) 0           leaky_re_lu_43[0][0]             
                                                                 leaky_re_lu_45[0][0]             
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 20, 20, 512)  524288      add_19[0][0]                     
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 20, 20, 512)  2048        conv2d_46[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_46[0][0]     
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_46[0][0]             
__________________________________________________________________________________________________
batch_normalization_47 (BatchNo (None, 20, 20, 1024) 4096        conv2d_47[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_47 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_47[0][0]     
__________________________________________________________________________________________________
add_20 (Add)                    (None, 20, 20, 1024) 0           add_19[0][0]                     
                                                                 leaky_re_lu_47[0][0]             
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 20, 20, 512)  524288      add_20[0][0]                     
__________________________________________________________________________________________________
batch_normalization_48 (BatchNo (None, 20, 20, 512)  2048        conv2d_48[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_48 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_48[0][0]     
__________________________________________________________________________________________________
conv2d_49 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_48[0][0]             
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 20, 20, 1024) 4096        conv2d_49[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_49 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_49[0][0]     
__________________________________________________________________________________________________
add_21 (Add)                    (None, 20, 20, 1024) 0           add_20[0][0]                     
                                                                 leaky_re_lu_49[0][0]             
__________________________________________________________________________________________________
conv2d_50 (Conv2D)              (None, 20, 20, 512)  524288      add_21[0][0]                     
__________________________________________________________________________________________________
batch_normalization_50 (BatchNo (None, 20, 20, 512)  2048        conv2d_50[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_50 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_50[0][0]     
__________________________________________________________________________________________________
conv2d_51 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_50[0][0]             
__________________________________________________________________________________________________
batch_normalization_51 (BatchNo (None, 20, 20, 1024) 4096        conv2d_51[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_51 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_51[0][0]     
__________________________________________________________________________________________________
add_22 (Add)                    (None, 20, 20, 1024) 0           add_21[0][0]                     
                                                                 leaky_re_lu_51[0][0]             
__________________________________________________________________________________________________
conv2d_52 (Conv2D)              (None, 20, 20, 512)  524288      add_22[0][0]                     
__________________________________________________________________________________________________
batch_normalization_52 (BatchNo (None, 20, 20, 512)  2048        conv2d_52[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_52 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_52[0][0]     
__________________________________________________________________________________________________
conv2d_53 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_52[0][0]             
__________________________________________________________________________________________________
batch_normalization_53 (BatchNo (None, 20, 20, 1024) 4096        conv2d_53[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_53 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_53[0][0]     
__________________________________________________________________________________________________
conv2d_54 (Conv2D)              (None, 20, 20, 512)  524288      leaky_re_lu_53[0][0]             
__________________________________________________________________________________________________
batch_normalization_54 (BatchNo (None, 20, 20, 512)  2048        conv2d_54[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_54 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_54[0][0]     
__________________________________________________________________________________________________
conv2d_55 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_54[0][0]             
__________________________________________________________________________________________________
batch_normalization_55 (BatchNo (None, 20, 20, 1024) 4096        conv2d_55[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_55 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_55[0][0]     
__________________________________________________________________________________________________
conv2d_56 (Conv2D)              (None, 20, 20, 512)  524288      leaky_re_lu_55[0][0]             
__________________________________________________________________________________________________
batch_normalization_56 (BatchNo (None, 20, 20, 512)  2048        conv2d_56[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_56 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_56[0][0]     
__________________________________________________________________________________________________
conv2d_59 (Conv2D)              (None, 20, 20, 256)  131072      leaky_re_lu_56[0][0]             
__________________________________________________________________________________________________
batch_normalization_58 (BatchNo (None, 20, 20, 256)  1024        conv2d_59[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_58 (LeakyReLU)      (None, 20, 20, 256)  0           batch_normalization_58[0][0]     
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 40, 40, 256)  0           leaky_re_lu_58[0][0]             
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 40, 40, 768)  0           up_sampling2d[0][0]              
                                                                 add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_60 (Conv2D)              (None, 40, 40, 256)  196608      concatenate[0][0]                
__________________________________________________________________________________________________
batch_normalization_59 (BatchNo (None, 40, 40, 256)  1024        conv2d_60[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_59 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_59[0][0]     
__________________________________________________________________________________________________
conv2d_61 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_59[0][0]             
__________________________________________________________________________________________________
batch_normalization_60 (BatchNo (None, 40, 40, 512)  2048        conv2d_61[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_60 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_60[0][0]     
__________________________________________________________________________________________________
conv2d_62 (Conv2D)              (None, 40, 40, 256)  131072      leaky_re_lu_60[0][0]             
__________________________________________________________________________________________________
batch_normalization_61 (BatchNo (None, 40, 40, 256)  1024        conv2d_62[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_61 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_61[0][0]     
__________________________________________________________________________________________________
conv2d_63 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_61[0][0]             
__________________________________________________________________________________________________
batch_normalization_62 (BatchNo (None, 40, 40, 512)  2048        conv2d_63[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_62 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_62[0][0]     
__________________________________________________________________________________________________
conv2d_64 (Conv2D)              (None, 40, 40, 256)  131072      leaky_re_lu_62[0][0]             
__________________________________________________________________________________________________
batch_normalization_63 (BatchNo (None, 40, 40, 256)  1024        conv2d_64[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_63 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_63[0][0]     
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, 40, 40, 128)  32768       leaky_re_lu_63[0][0]             
__________________________________________________________________________________________________
batch_normalization_65 (BatchNo (None, 40, 40, 128)  512         conv2d_67[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_65 (LeakyReLU)      (None, 40, 40, 128)  0           batch_normalization_65[0][0]     
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 80, 80, 128)  0           leaky_re_lu_65[0][0]             
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 80, 80, 384)  0           up_sampling2d_1[0][0]            
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, 80, 80, 128)  49152       concatenate_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_66 (BatchNo (None, 80, 80, 128)  512         conv2d_68[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_66 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_66[0][0]     
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_66[0][0]             
__________________________________________________________________________________________________
batch_normalization_67 (BatchNo (None, 80, 80, 256)  1024        conv2d_69[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_67 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_67[0][0]     
__________________________________________________________________________________________________
conv2d_70 (Conv2D)              (None, 80, 80, 128)  32768       leaky_re_lu_67[0][0]             
__________________________________________________________________________________________________
batch_normalization_68 (BatchNo (None, 80, 80, 128)  512         conv2d_70[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_68 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_68[0][0]     
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_68[0][0]             
__________________________________________________________________________________________________
batch_normalization_69 (BatchNo (None, 80, 80, 256)  1024        conv2d_71[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_69 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_69[0][0]     
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 80, 80, 128)  32768       leaky_re_lu_69[0][0]             
__________________________________________________________________________________________________
batch_normalization_70 (BatchNo (None, 80, 80, 128)  512         conv2d_72[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_70 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_70[0][0]     
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_56[0][0]             
__________________________________________________________________________________________________
conv2d_65 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_63[0][0]             
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_70[0][0]             
__________________________________________________________________________________________________
batch_normalization_57 (BatchNo (None, 20, 20, 1024) 4096        conv2d_57[0][0]                  
__________________________________________________________________________________________________
batch_normalization_64 (BatchNo (None, 40, 40, 512)  2048        conv2d_65[0][0]                  
__________________________________________________________________________________________________
batch_normalization_71 (BatchNo (None, 80, 80, 256)  1024        conv2d_73[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_57 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_57[0][0]     
__________________________________________________________________________________________________
leaky_re_lu_64 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_64[0][0]     
__________________________________________________________________________________________________
leaky_re_lu_71 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_71[0][0]     
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 20, 20, 288)  295200      leaky_re_lu_57[0][0]             
__________________________________________________________________________________________________
conv2d_66 (Conv2D)              (None, 40, 40, 288)  147744      leaky_re_lu_64[0][0]             
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 80, 80, 288)  74016       leaky_re_lu_71[0][0]             
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, None, 3 0           conv2d_58[0][0]                  
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None, None, 3 0           conv2d_66[0][0]                  
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, None, None, 3 0           conv2d_74[0][0]                  
__________________________________________________________________________________________________
yolo_boxes_first_out (Lambda)   ((None, None, None,  0           lambda[0][0]                     
__________________________________________________________________________________________________
yolo_boxes_second_out (Lambda)  ((None, None, None,  0           lambda_1[0][0]                   
__________________________________________________________________________________________________
yolo_boxes_third_out (Lambda)   ((None, None, None,  0           lambda_2[0][0]                   
==================================================================================================
Total params: 62,060,992
Trainable params: 62,008,384
Non-trainable params: 52,608
__________________________________________________________________________________________________
None

Now we apply non-maximum suppression (NMS) to the outputs of the detection heads. First, change the input used when building the network to the image itself:

x1, x2, x3 = darknet.build_darknet(img, "darknet")
# Inference branch; the matching `if` branch (training mode, no NMS) is omitted here
else:
    outputs = layers.Lambda(lambda x: yolo_nms(x, 91),
                            name='yolo_nms')([
        [first_out_bbox, first_out_objectness, first_out_class_probs],
        [second_out_bbox, second_out_objectness, second_out_class_probs],
        [third_out_bbox, third_out_objectness, third_out_class_probs]
    ])
    model = models.Model(inputs=inputs, outputs=outputs)
    print(model.summary())
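The three detection heads (20×20, 40×40 and 80×80 grids, with 3 anchor boxes per grid cell) flatten into 25200 candidate boxes, which is the count that appears in the NMS output further down. A quick sanity check in plain Python (the variable names are ours):

```python
# Each scale predicts 3 anchor boxes per grid cell; flattening the
# 20x20, 40x40 and 80x80 heads and concatenating them gives every candidate box.
num_anchors = 3
total_boxes = sum(num_anchors * s * s for s in (20, 40, 80))
print(total_boxes)  # 25200
```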

The yolo_nms code is as follows:

def yolo_nms(yolo_pred, num_class):
    """ Apply non-maximum suppression to the predicted boxes
    """
    boxes, objectness, class_probs = [], [], []

    # pred: [bbox, objectness, class_probs]
    # Flatten and stack the boxes of all three feature maps
    for pred in yolo_pred:
        # boxes: [batch, -1, 4]
        boxes.append(tf.reshape(pred[0], (tf.shape(pred[0])[0], -1, tf.shape(pred[0])[-1])))
        # objectness: [batch, -1, 1]
        objectness.append(tf.reshape(pred[1], (tf.shape(pred[1])[0], -1, tf.shape(pred[1])[-1])))
        # class_probs: [batch, -1, num_classes]
        class_probs.append(tf.reshape(pred[2], (tf.shape(pred[2])[0], -1, tf.shape(pred[2])[-1])))
    # Concatenate along axis=1, the box dimension
    bbox = tf.concat(boxes, axis=1)
    objectness = tf.concat(objectness, axis=1)
    class_probs = tf.concat(class_probs, axis=1)

    final_batch_nms_bboxes = []
    final_batch_nms_scores = []
    final_batch_nms_classes = []
    valid_detection_nums = []
    batch_size = 1
    for b in range(batch_size):
        # objectness * class probability is the score NMS sorts by
        cur_scores = objectness[b] * class_probs[b]

        # In test mode the batch dimension is 1, so it can be flattened away
        cur_dscores = tf.reshape(cur_scores, (-1, num_class))
        cur_bbox = tf.reshape(bbox[b], (-1, 4))

        # Take the largest class probability of each row ...
        cur_scores = tf.reduce_max(cur_dscores, [1])
        print('cur_scores', cur_scores)
        # ... and the index of that class
        cur_classes = tf.argmax(cur_dscores, 1)
        print('cur_classes', cur_classes)
        # Non-maximum suppression: returns up to 100 indices and their scores
        selected_indices, selected_scores = tf.image.non_max_suppression_with_scores(
            boxes=cur_bbox,
            scores=cur_scores,
            max_output_size=100,
            iou_threshold=0.5,
            score_threshold=0.007,
            soft_nms_sigma=0.5
        )
        print('selected_indices', selected_indices)
        print('selected_scores', selected_scores)
        # Number of boxes that survived NMS; pad the rest up to 100 with zeros
        valid_num = tf.shape(selected_indices)[0]
        valid_detection_nums.append(valid_num)
        pad_num = 100 - valid_num

        # [N, (x1, y1, x2, y2)]
        # Gather the boxes kept by NMS
        cur_bbox = tf.gather(cur_bbox, selected_indices)
        # Pad rows of zeros below the valid boxes
        cur_bbox = tf.pad(cur_bbox, [[0, pad_num], [0, 0]])
        cur_bbox = tf.expand_dims(cur_bbox, axis=0)
        print('cur_bbox', cur_bbox)
        final_batch_nms_bboxes.append(cur_bbox)

        # [1, N]
        # The best class score of each kept box
        cur_scores = tf.pad(selected_scores, [[0, pad_num]])
        cur_scores = tf.expand_dims(cur_scores, axis=0)
        print('cur_scores', cur_scores)
        final_batch_nms_scores.append(cur_scores)

        # [1, N]
        # The class index of each kept box
        cur_classes = tf.gather(cur_classes, selected_indices)
        cur_classes = tf.pad(cur_classes, [[0, pad_num]])
        cur_classes = tf.expand_dims(cur_classes, axis=0)
        print('cur_classes', cur_classes)
        final_batch_nms_classes.append(cur_classes)

    final_batch_nms_bboxes = tf.concat(final_batch_nms_bboxes, axis=0)
    final_batch_nms_scores = tf.concat(final_batch_nms_scores, axis=0)
    final_batch_nms_classes = tf.concat(final_batch_nms_classes, axis=0)
    valid_detection_nums = tf.stack(valid_detection_nums, axis=0)

    return final_batch_nms_bboxes, final_batch_nms_scores, final_batch_nms_classes, valid_detection_nums

The output is:

cur_scores tf.Tensor([0.00585971 0.00599038 0.00595966 ... 0.00554877 0.00553788 0.00553541], shape=(25200,), dtype=float32)
cur_classes tf.Tensor([45 47 24 ... 34  9 80], shape=(25200,), dtype=int64)
selected_indices tf.Tensor(
[4866  818 4764 4500 4662 4398  476 4728 2961 4176 4242 5097 3102 5118
 2823 4026 3444  321 4344 5028 3702 2484 2922 4686 2625 2781 3921 3306
 2445 4308 4833 2346 3768  723 4632 2304 4206], shape=(37,), dtype=int32)
selected_scores tf.Tensor(
[0.00931381 0.00896284 0.00856594 0.00833606 0.00819792 0.00801256
 0.00801134 0.00800271 0.00784801 0.00777876 0.00775182 0.00774399
 0.00770376 0.00768892 0.00760151 0.00760061 0.00757784 0.00753411
 0.00752614 0.00748221 0.00747792 0.00742623 0.00737575 0.00737422
 0.00735858 0.00733442 0.00733046 0.00730357 0.00720255 0.00719151
 0.00717866 0.00716299 0.0071443  0.00714163 0.00708044 0.0070156
 0.00700225], shape=(37,), dtype=float32)
cur_bbox tf.Tensor(
[[[0.47525817 0.7058009  0.65038186 0.8253993 ]
  [0.2290636  0.43899518 1.0120833  0.90559846]
  [0.6245662  0.67990446 0.8013499  0.8008765 ]
  [0.425923   0.6314819  0.5999857  0.74891996]
  [0.7734392  0.655851   0.9523676  0.774408  ]
  [0.57445204 0.6052354  0.7513912  0.72489446]
  [0.49578106 0.13718146 1.3462695  0.61382943]
  [0.32203868 0.68235976 0.5037507  0.79731053]
  [0.59745276 0.3060625  0.77827334 0.423403  ]
  [0.7227372  0.5557441  0.9029784  0.67373496]
  [0.27246937 0.58118933 0.4531304  0.6981532 ]
  [0.38823918 0.7577431  0.5873878  0.8722745 ]
  [0.77167207 0.33106077 0.9540772  0.44833946]
  [0.56604093 0.7558559  0.75961024 0.8745059 ]
  [0.44609636 0.28142396 0.6295416  0.39770314]
  [0.47225362 0.5309098  0.65343183 0.6481316 ]
  [0.620741   0.4065663  0.8049351  0.52262986]
  [0.22502528 0.         0.523556   0.5531311 ]
  [0.11964019 0.6067808  0.3058835  0.72229236]
  [0.81782186 0.729955   1.007774   0.84961873]
  [0.7708911  0.45628837 0.9547696  0.5728136 ]
  [0.6212652  0.20621267 0.8043301  0.32260957]
  [0.2703642  0.30675125 0.45527455 0.42211968]
  [0.         0.68376225 0.1580659  0.7943612 ]
  [0.79585207 0.23102117 0.9797478  0.3477431 ]
  [0.09455924 0.2818944  0.28103435 0.396766  ]
  [0.59682685 0.50606334 0.7788772  0.62299573]
  [0.4687725  0.38219234 0.65683836 0.49669072]
  [0.2954356  0.2064167  0.4800799  0.3222209 ]
  [0.82373786 0.5792823  1.0020474  0.7000315 ]
  [0.18977335 0.7078563  0.3858095  0.82072246]
  [0.46975034 0.1813986  0.6558301  0.29699862]
  [0.31987333 0.48147443 0.50574625 0.59682184]
  [0.         0.34287256 0.22398928 0.9016189 ]
  [0.52763104 0.6548126  0.69821584 0.7761345 ]
  [0.11914106 0.1818139  0.3063051  0.29622966]
  [0.         0.5828472  0.15521356 0.6949684 ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]]], shape=(1, 100, 4), dtype=float32)
cur_scores tf.Tensor(
[[0.00931381 0.00896284 0.00856594 0.00833606 0.00819792 0.00801256
  0.00801134 0.00800271 0.00784801 0.00777876 0.00775182 0.00774399
  0.00770376 0.00768892 0.00760151 0.00760061 0.00757784 0.00753411
  0.00752614 0.00748221 0.00747792 0.00742623 0.00737575 0.00737422
  0.00735858 0.00733442 0.00733046 0.00730357 0.00720255 0.00719151
  0.00717866 0.00716299 0.0071443  0.00714163 0.00708044 0.0070156
  0.00700225 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]], shape=(1, 100), dtype=float32)
cur_classes tf.Tensor(
[[29 28 29 29 29 29 39 29 29 29 29 29 29 29 29 29 29 27 29 29 29 29 29 29
  29 29 29 29 29 29 29 29 29 20 29 29 29  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]], shape=(1, 100), dtype=int64)
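tf.image.non_max_suppression_with_scores above performs soft-NMS (soft_nms_sigma=0.5). For intuition, here is a minimal pure-NumPy sketch of the classic hard variant of the same greedy rule: keep the highest-scoring box, drop every remaining box whose IoU with it exceeds the threshold, and repeat. The function names are illustrative.

```python
import numpy as np

def box_area(b):
    # b: (..., 4) boxes in (x1, y1, x2, y2) format
    return (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])

def iou(box, boxes):
    # IoU between one box and an array of boxes
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    return inter / (box_area(box) + box_area(boxes) - inter)

def hard_nms(boxes, scores, iou_threshold=0.5):
    # Greedy NMS: keep the best remaining box, drop its overlaps, repeat
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        order = rest[iou(boxes[best], boxes[rest]) <= iou_threshold]
    return keep

boxes = np.array([[0.0, 0.0, 1.0, 1.0],    # heavily overlaps the next box
                  [0.1, 0.1, 1.1, 1.1],
                  [2.0, 2.0, 3.0, 3.0]])   # far away, always kept
scores = np.array([0.9, 0.8, 0.7])
print(hard_nms(boxes, scores))  # [0, 2] -- box 1 is suppressed by box 0
```

Soft-NMS differs only in that, instead of discarding overlapping boxes outright, it decays their scores by a Gaussian of the IoU, which is why the selected_scores above are not simply the raw maxima.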

Training on the COCO dataset

Here we change the input passed to the network builder to a TensorFlow layers.Input:

x1, x2, x3 = darknet.build_darknet(inputs, "darknet")

To use the COCO dataset, we need to install the COCO toolkit:

pip install pycocotools==2.0.0

Then create the optimizer and the training-data generator:

optimizer = optimizers.Adam(learning_rate=0.001)
train_data = CoCoDataGenrator(
    coco_annotation_file="./data/instances_val2017.json",
    img_shape=[640, 640, 3],
    batch_size=5,
    max_instances=100
)

Let us first look at the format of instances_val2017.json. It contains info and licenses, which are unimportant here and may be empty, plus images and annotations, which hold the images and their labels. images is a list whose items are dicts; one image entry looks like this (coco_url is the image's URL, height and width are its height and width):

"images": [{
    "license": 4,
    "file_name": "000000397133.jpg",
    "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg",
    "height": 427,
    "width": 640,
    "date_captured": "2013-11-14 17:02:52",
    "flickr_url": "http://farm7.staticflickr.com/6116/6255196340_da26cf2c9e_z.jpg",
    "id": 397133
}]

annotations is also a list whose items are dicts; one annotation entry looks like this:

"annotations": [{"segmentation": [[510.66,423.01,511.72,420.03,510.45,416.0,510.34,413.02,510.77,410.26,510.77,407.5,510.34,405.16,511.51,402.83,511.41,400.49,510.24,398.16,509.39,397.31,504.61,399.22,502.17,399.64,500.89,401.66,500.47,402.08,499.09,401.87,495.79,401.98,490.59,401.77,488.79,401.77,485.39,398.58,483.9,397.31,481.56,396.35,478.48,395.93,476.68,396.03,475.4,396.77,473.92,398.79,473.28,399.96,473.49,401.87,474.56,403.47,473.07,405.59,473.39,407.71,476.68,409.41,479.23,409.73,481.56,410.69,480.4,411.85,481.35,414.93,479.86,418.65,477.32,420.03,476.04,422.58,479.02,422.58,480.29,423.01,483.79,419.93,486.66,416.21,490.06,415.57,492.18,416.85,491.65,420.24,492.82,422.9,493.56,424.39,496.43,424.6,498.02,423.01,498.13,421.31,497.07,420.03,497.07,415.15,496.33,414.51,501.1,411.96,502.06,411.32,503.02,415.04,503.33,418.12,501.1,420.24,498.98,421.63,500.47,424.39,505.03,423.32,506.2,421.31,507.69,419.5,506.31,423.32,510.03,423.01,510.45,423.01]],
"area": 702.1057499999998,
"iscrowd": 0,
"image_id": 289343,
"bbox": [473.07,395.93,38.65,28.67],
"category_id": 18,
"id": 1768}]

image_id refers to the id of the corresponding entry in images; one image may have several annotations, since each annotation labels only a single object. category_id identifies the class. segmentation is the segmentation label, area is the size of the segmented region, iscrowd indicates whether the annotation marks a crowd of objects labeled as a group rather than individually, and bbox is the object's bounding box.
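Note that bbox is stored as [xmin, ymin, width, height], while the data generator below works with corner coordinates. A minimal sketch of the conversion it performs (the function name is ours):

```python
def coco_bbox_to_corners(bbox):
    # COCO stores [xmin, ymin, width, height]; convert to [xmin, ymin, xmax, ymax],
    # truncating to integers the same way the generator does
    xmin, ymin, w, h = bbox
    xmin, ymin = int(xmin), int(ymin)
    return [xmin, ymin, int(xmin + w), int(ymin + h)]

# The bbox of the example annotation above:
print(coco_bbox_to_corners([473.07, 395.93, 38.65, 28.67]))  # [473, 395, 511, 423]
```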

The code of the CoCoDataGenrator class is as follows:

import cv2
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io


class CoCoDataGenrator:
    def __init__(self,
                 coco_annotation_file,
                 img_shape=(640, 640, 3),
                 batch_size=1,
                 max_instances=100,
                 include_crowd=False,
                 include_mask=False,
                 include_keypoint=False):
        self.img_shape = img_shape
        self.batch_size = batch_size
        self.max_instances = max_instances
        self.include_crowd = include_crowd
        self.include_mask = include_mask
        self.include_keypoint = include_keypoint

        self.current_batch_index = 0
        self.total_batch_size = 0
        self.img_ids = []
        self.coco = COCO(annotation_file=coco_annotation_file)
        self.load_data()

    def load_data(self):
        # Pre-filter: keep only images whose annotations match the crowd setting
        target_img_ids = []
        for k in self.coco.imgToAnns:
            annos = self.coco.imgToAnns[k]
            if annos:
                annos = list(filter(lambda x: x['iscrowd'] == self.include_crowd, annos))
                if annos:
                    target_img_ids.append(k)
        self.total_batch_size = len(target_img_ids) // self.batch_size
        self.img_ids = target_img_ids

    def next_batch(self):
        if self.current_batch_index >= self.total_batch_size:
            self.current_batch_index = 0
            self._on_epoch_end()

        batch_img_ids = self.img_ids[self.current_batch_index * self.batch_size:
                                     (self.current_batch_index + 1) * self.batch_size]
        batch_imgs = []
        batch_bboxes = []
        batch_labels = []
        batch_masks = []
        batch_keypoints = []
        for img_id in batch_img_ids:
            # {"img":, "bboxes":, "labels":, "masks":, "key_points":}
            data = self._data_generation(image_id=img_id)
            if len(np.shape(data['img'])) > 0:
                batch_imgs.append(data['img'])

                if len(data['labels']) > self.max_instances:
                    batch_bboxes.append(data['bboxes'][:self.max_instances, :])
                    batch_labels.append(data['labels'][:self.max_instances])
                else:
                    pad_num = self.max_instances - len(data['labels'])
                    batch_bboxes.append(np.pad(data['bboxes'], [(0,pad_num), (0, 0)]))
                    batch_labels.append(np.pad(data['labels'], [(0,pad_num)]))

                if self.include_mask:
                    batch_masks.append(data['masks'])

                if self.include_keypoint:
                    batch_keypoints.append(data['keypoints'])

        self.current_batch_index += 1

        if len(batch_imgs) < self.batch_size:
            return self.next_batch()

        output = {
            'imgs': np.array(batch_imgs, dtype=np.int32),
            'bboxes': np.array(batch_bboxes, dtype=np.int16),
            'labels': np.array(batch_labels, dtype=np.int8),
            'masks': np.array(batch_masks, dtype=np.int8),
            'keypoints': np.array(batch_keypoints, dtype=np.int16)
        }

        return output

    def _on_epoch_end(self):
        np.random.shuffle(self.img_ids)

    def _resize_im(self, origin_im, bboxes):
        """ Resize the image, scaling its boxes by the same factor

        :param origin_im:
        :param bboxes:
        :return im_blob: [h, w, 3]
                gt_boxes: [N, [xmin, ymin, xmax, ymax]]
        """
        im_shape = np.shape(origin_im)
        im_size_max = np.max(im_shape[0:2])
        im_scale = float(self.img_shape[0]) / float(im_size_max)

        # Resize the original image and paste it into the top-left of a zero canvas
        im_resize = cv2.resize(origin_im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
        im_resize_shape = np.shape(im_resize)
        im_blob = np.zeros(self.img_shape, dtype=np.float32)
        im_blob[0:im_resize_shape[0], 0:im_resize_shape[1], :] = im_resize

        # Scale the corresponding boxes
        bboxes_resize = np.array(bboxes * im_scale, dtype=np.int16)

        return im_blob, bboxes_resize

    def _resize_mask(self, origin_masks):
        """ Resize the mask data
        :param origin_masks:
        :return: masks_resize: [h, w, instance_num]
                 gt_boxes: [N, [xmin, ymin, xmax, ymax]]
        """
        mask_shape = np.shape(origin_masks)
        # origin_masks is [N, h, w], so the spatial dims are indices 1 and 2
        mask_size_max = np.max(mask_shape[1:3])
        im_scale = float(self.img_shape[0]) / float(mask_size_max)

        # Resize each mask and derive its box
        gt_boxes = []
        masks_resize = []
        for m in origin_masks:
            m = np.array(m, dtype=np.float32)
            m_resize = cv2.resize(m, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
            m_resize = np.array(m_resize >= 0.5, dtype=np.int8)

            # Compute the bounding box from the mask pixels
            h, w = np.shape(m_resize)
            rows, cols = np.where(m_resize)
            # [xmin, ymin, xmax, ymax]
            xmin = np.min(cols) if np.min(cols) >= 0 else 0
            ymin = np.min(rows) if np.min(rows) >= 0 else 0
            xmax = np.max(cols) if np.max(cols) <= w else w
            ymax = np.max(rows) if np.max(rows) <= h else h
            bdbox = [xmin, ymin, xmax, ymax]
            gt_boxes.append(bdbox)

            mask_blob = np.zeros((self.img_shape[0], self.img_shape[1], 1), dtype=np.float32)
            mask_blob[0:h, 0:w, 0] = m_resize
            masks_resize.append(mask_blob)

        # [instance_num, [xmin, ymin, xmax, ymax]]
        gt_boxes = np.array(gt_boxes, dtype=np.int16)
        # [h, w, instance_num]
        masks_resize = np.concatenate(masks_resize, axis=-1)

        return masks_resize, gt_boxes

    def _data_generation(self, image_id):
        """ Pull the COCO annotation data: boxes, classes, masks
        :param image_id:
        :return:
        """

        anno_ids = self.coco.getAnnIds(imgIds=image_id, iscrowd=self.include_crowd)
        bboxes = []
        labels = []
        masks = []
        keypoints = []
        for i in anno_ids:
            # Box: convert [x, y, w, h] to top-left/bottom-right corners
            ann = self.coco.anns[i]
            bbox = ann['bbox']
            xmin, ymin, w, h = bbox
            xmin = int(xmin)
            ymin = int(ymin)
            xmax = int(xmin + w)
            ymax = int(ymin + h)
            bboxes.append([xmin, ymin, xmax, ymax])
            # Class id
            label = ann['category_id']
            labels.append(label)
            # Instance segmentation
            if self.include_mask:
                mask = self.coco.annToMask(ann)
                masks.append(mask)
            if self.include_keypoint and ann.get('keypoints'):
                keypoint = ann['keypoints']
                # Reshape to [x, y, v]: v=0 means the point is absent,
                # v=1 labeled but occluded, v=2 visible
                keypoint = np.reshape(keypoint, [-1, 3])
                keypoints.append(keypoint)

        # The output holds 5 items; unused ones stay empty
        outputs = {
            "img": [],
            "labels": [],
            "bboxes": [],
            "masks": [],
            "keypoints": []
        }

        # Final mask data
        if self.include_mask:
            # [N, h, w]
            masks, bboxes = self._resize_mask(origin_masks=masks)
            outputs['masks'] = masks
            outputs['bboxes'] = bboxes

        # Final keypoint data (int16: pixel coordinates exceed the int8 range)
        if self.include_keypoint:
            keypoints = np.array(keypoints, dtype=np.int16)
            outputs['keypoints'] = keypoints

        img = io.imread(self.coco.imgs[image_id]['coco_url'])
        if len(np.shape(img)) < 2:
            return outputs
        elif len(np.shape(img)) == 2:
            # Grayscale image: replicate the single channel into RGB
            img = np.expand_dims(img, axis=-1)
            img = np.tile(img, (1, 1, 3))

        labels = np.array(labels, dtype=np.int8)
        bboxes = np.array(bboxes, dtype=np.int16)
        img_resize, bboxes_resize = self._resize_im(origin_im=img, bboxes=bboxes)
        outputs['img'] = img_resize
        outputs['labels'] = labels
        outputs['bboxes'] = bboxes_resize

        return outputs
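_resize_im above scales the longer image side down to img_shape[0], pastes the result into the top-left corner of a zero canvas, and multiplies the boxes by the same factor. A small NumPy sketch of that arithmetic (the helper name is ours):

```python
import numpy as np

def resize_scale_and_boxes(img_hw, bboxes, target=640):
    # Scale so the longer side equals `target`; boxes scale by the same factor
    scale = float(target) / max(img_hw)
    new_hw = (int(round(img_hw[0] * scale)), int(round(img_hw[1] * scale)))
    return new_hw, np.array(np.array(bboxes) * scale, dtype=np.int16), scale

# A 320x320 image is scaled by 640/320 = 2.0 and its boxes grow with it;
# the remainder of the 640x640 canvas stays zero-padded.
new_hw, boxes, scale = resize_scale_and_boxes((320, 320), [[30, 60, 90, 120]])
print(new_hw, scale)   # (640, 640) 2.0
print(boxes.tolist())  # [[60, 120, 180, 240]]
```

Because aspect ratio is preserved and the padding sits on the right/bottom, box coordinates never need an offset, only the single multiplication.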

Now we are ready to train on the COCO dataset:

# 獲取數據集的分類類別
classes = train_data.coco.cats
log_dir = "./logs"
summary_writer = tf.summary.create_file_writer(log_dir)
epochs = 101
for epoch in range(epochs):
    if epoch % 20 == 0 and epoch != 0:
        model.save_weights(log_dir + '/yolov3-tf-{}.h5'.format(epoch))
    for batch in range(train_data.total_batch_size):
        with tf.GradientTape() as tape:
            data = train_data.next_batch()
            # get the images, bounding boxes and labels of the batch samples
            gt_imgs = data['imgs'] / 255.
            gt_boxes = data['bboxes'] / image_shape[0]
            gt_classes = data['labels']
            print('gt_imgs', gt_imgs)
            print('gt_boxes', gt_boxes)
            print('gt_classes', gt_classes)
            # build the target values YOLO needs for training
            yolo_targets = transform_targets(
                gt_boxes=gt_boxes,
                gt_lables=gt_classes,
                anchors=anchors,
                anchor_masks=anchor_masks,
                im_size=image_shape[0]
            )
            yolo_preds = model(gt_imgs, training=True)

Output:

gt_imgs [[[[0.15294118 0.15294118 0.15294118]
   [0.09803922 0.09803922 0.09803922]
   [0.08235294 0.08235294 0.08235294]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.14117647 0.14117647 0.14117647]
   [0.14117647 0.14117647 0.14117647]
   [0.12156863 0.12156863 0.12156863]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.07058824 0.07058824 0.07058824]
   [0.09803922 0.09803922 0.09803922]
   [0.10196078 0.10196078 0.10196078]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  ...

  [[0.13333333 0.13333333 0.13333333]
   [0.1254902  0.1254902  0.1254902 ]
   [0.14509804 0.14509804 0.14509804]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.1254902  0.1254902  0.1254902 ]
   [0.14901961 0.14901961 0.14901961]
   [0.16862745 0.16862745 0.16862745]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.14509804 0.14509804 0.14509804]
   [0.16470588 0.16470588 0.16470588]
   [0.15294118 0.15294118 0.15294118]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]]


 [[[0.89803922 0.89803922 0.89803922]
   [0.89411765 0.89411765 0.89411765]
   [0.89411765 0.89411765 0.89411765]
   ...
   [0.78039216 0.78039216 0.77254902]
   [0.83921569 0.8627451  0.85490196]
   [0.91764706 0.94117647 0.94117647]]

  [[0.89803922 0.89803922 0.89803922]
   [0.89411765 0.89411765 0.89411765]
   [0.89411765 0.89411765 0.89411765]
   ...
   [0.79607843 0.79607843 0.78823529]
   [0.84705882 0.87058824 0.8627451 ]
   [0.90588235 0.92941176 0.92941176]]

  [[0.89803922 0.89803922 0.89803922]
   [0.89411765 0.89411765 0.89411765]
   [0.89411765 0.89411765 0.89411765]
   ...
   [0.80392157 0.81176471 0.8       ]
   [0.85490196 0.87843137 0.87058824]
   [0.89411765 0.90980392 0.91372549]]

  ...

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]]


 [[[0.49803922 0.70196078 0.17647059]
   [0.49019608 0.69411765 0.16862745]
   [0.47843137 0.68235294 0.15686275]
   ...
   [0.50588235 0.70588235 0.15686275]
   [0.51764706 0.71764706 0.17254902]
   [0.49019608 0.69019608 0.15294118]]

  [[0.58431373 0.78823529 0.25098039]
   [0.58039216 0.78431373 0.25098039]
   [0.57647059 0.78039216 0.24313725]
   ...
   [0.60392157 0.80784314 0.24705882]
   [0.60392157 0.80392157 0.25098039]
   [0.56470588 0.76862745 0.22352941]]

  [[0.60784314 0.82352941 0.24313725]
   [0.6        0.81960784 0.23921569]
   [0.6        0.81568627 0.24705882]
   ...
   [0.61960784 0.82352941 0.25098039]
   [0.6        0.80392157 0.23921569]
   [0.55686275 0.76078431 0.2       ]]

  ...

  [[0.45882353 0.69803922 0.09803922]
   [0.45882353 0.69803922 0.09803922]
   [0.45882353 0.69803922 0.09803922]
   ...
   [0.61176471 0.85490196 0.17647059]
   [0.61176471 0.83921569 0.17647059]
   [0.59607843 0.82745098 0.16862745]]

  [[0.45882353 0.69803922 0.09803922]
   [0.45882353 0.69803922 0.09803922]
   [0.45882353 0.69803922 0.09803922]
   ...
   [0.61176471 0.85098039 0.18431373]
   [0.61176471 0.84313725 0.18431373]
   [0.6        0.83137255 0.17254902]]

  [[0.45882353 0.69803922 0.09803922]
   [0.45882353 0.69803922 0.09803922]
   [0.45882353 0.69803922 0.09803922]
   ...
   [0.61176471 0.85490196 0.18431373]
   [0.61176471 0.84313725 0.18431373]
   [0.60392157 0.83529412 0.18431373]]]


 [[[0.24313725 0.27843137 0.2745098 ]
   [0.09411765 0.10980392 0.10588235]
   [0.08235294 0.08235294 0.0745098 ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.25098039 0.28235294 0.29019608]
   [0.10196078 0.11764706 0.12156863]
   [0.08627451 0.08627451 0.08627451]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.24705882 0.27843137 0.29019608]
   [0.09803922 0.11372549 0.1254902 ]
   [0.07843137 0.07843137 0.07843137]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  ...

  [[0.05490196 0.05490196 0.05490196]
   [0.05490196 0.05490196 0.05490196]
   [0.05490196 0.05490196 0.05490196]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.05490196 0.05490196 0.05490196]
   [0.05490196 0.05490196 0.05490196]
   [0.05490196 0.05490196 0.05490196]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.05882353 0.05882353 0.05882353]
   [0.05490196 0.05490196 0.05490196]
   [0.05490196 0.05490196 0.05490196]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]]


 [[[0.68627451 0.66666667 0.65490196]
   [0.68235294 0.66666667 0.65490196]
   [0.6627451  0.64705882 0.63529412]
   ...
   [0.50980392 0.49019608 0.4745098 ]
   [0.49803922 0.47843137 0.4627451 ]
   [0.48235294 0.47058824 0.45098039]]

  [[0.6745098  0.65490196 0.64313725]
   [0.6745098  0.65490196 0.64313725]
   [0.67843137 0.65882353 0.64705882]
   ...
   [0.52156863 0.50196078 0.48627451]
   [0.50588235 0.48627451 0.47058824]
   [0.49411765 0.4745098  0.45882353]]

  [[0.66666667 0.64313725 0.64313725]
   [0.65098039 0.62745098 0.62745098]
   [0.6627451  0.63921569 0.63921569]
   ...
   [0.54901961 0.51764706 0.50588235]
   [0.51764706 0.49803922 0.48235294]
   [0.49803922 0.47843137 0.4627451 ]]

  ...

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]

  [[0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]]]
gt_boxes [[[0.7390625 0.6171875 0.7984375 0.6609375]
  [0.31875   0.3671875 0.4125    0.64375  ]
  [0.        0.7796875 0.5296875 0.9453125]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.425     0.3125    0.6609375 0.7484375]
  [0.2828125 0.134375  0.325     0.2484375]
  [0.271875  0.        0.6796875 0.34375  ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.2015625 0.31875   0.809375  0.9015625]
  [0.0953125 0.08125   0.9140625 0.71875  ]
  [0.4703125 0.1140625 0.5140625 0.15625  ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.175     0.240625  0.7484375 0.9890625]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.3125    0.1390625 0.9375    0.53125  ]
  [0.146875  0.        0.521875  0.3296875]
  [0.        0.        0.        0.       ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]]
gt_classes [[18  1 15  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18 44 70  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18  4 47 47  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]]

Let us look at the transform_targets method. It computes the IoU (intersection over union) between each ground-truth box in the COCO annotations and the 9 anchors, picks the best-matching anchor, and emits that box's coordinates, a confidence of 1, and its class label as the training target. In other words, YOLO does not use the raw image annotations directly as labels; the labels are built per grid cell, with 3 anchor sizes at each of the three feature-map scales.

def transform_targets(gt_boxes, gt_lables, anchors, anchor_masks, im_size):
    """ Compute the YOLO training targets
    :param gt_boxes: [batch, num_boxes, (x1, y1, x2, y2)]
    :param gt_lables: [batch, num_boxes]
    :param anchors: [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
                    (59, 119), (116, 90), (156, 198), (373, 326)] / im_size
    :param anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    :param im_size: input image size (e.g. 416)
    :return: ([N, grid, grid, anchors, [x1, y1, x2, y2, obj, class]], [], [])
    """
    y_outs = []
    # grid size of the coarsest output layer: 1/32 of the input size
    grid_size = im_size // 32

    # areas of the 9 anchors (already normalized by the image size)
    anchors = np.array(anchors, np.float32)
    anchor_area = anchors[..., 0] * anchors[..., 1]
    print('anchor_area', anchor_area)

    # width/height of each gt_box (also normalized)
    box_wh = gt_boxes[..., 2:4] - gt_boxes[..., 0:2]
    box_wh = np.tile(np.expand_dims(box_wh, axis=-2),
                     (1, 1, np.shape(anchors)[0], 1))
    box_area = box_wh[..., 0] * box_wh[..., 1]
    print('box_area', box_area)

    # IoU between each box and each anchor, treating them as co-centered
    intersection = np.minimum(box_wh[..., 0], anchors[..., 0]) * np.minimum(box_wh[..., 1], anchors[..., 1])
    iou = intersection / (box_area + anchor_area - intersection)
    print('iou', iou)
    print('ioushape', iou.shape)
    # index of the anchor with the highest IoU for each ground-truth box
    anchor_idx = np.array(np.argmax(iou, axis=-1), np.float32)
    anchor_idx = np.expand_dims(anchor_idx, axis=-1)
    gt_labels = np.expand_dims(gt_lables, axis=-1)

    # concatenate boxes, labels and the best-anchor index into the per-box target
    y_train = np.concatenate([gt_boxes, gt_labels, anchor_idx], axis=-1)
    # compute the final target values once per feature layer (three layers in total)
    for anchor_idxs in anchor_masks:
        y_outs.append(transform_targets_for_output(y_train, grid_size, anchor_idxs))
        grid_size *= 2

    return tuple(y_outs)
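The anchor matching inside transform_targets can be checked by hand. The sketch below uses the standard 9 anchors and a 416-pixel input, as in the docstring above, and reproduces the width/height-only IoU for a single hypothetical ground-truth box of roughly 100×90 pixels; it is matched to anchor 6, i.e. (116, 90), which belongs to the coarsest output layer:

```python
import numpy as np

anchors = np.array([(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
                    (59, 119), (116, 90), (156, 198), (373, 326)],
                   np.float32) / 416.0

# One ground-truth box of about 100x90 pixels, normalized by the image size.
box_wh = np.array([100 / 416.0, 90 / 416.0], np.float32)

# Width/height-only IoU: both boxes are treated as if they share a center.
intersection = np.minimum(box_wh[0], anchors[:, 0]) * np.minimum(box_wh[1], anchors[:, 1])
union = box_wh[0] * box_wh[1] + anchors[:, 0] * anchors[:, 1] - intersection
iou = intersection / union

best = int(np.argmax(iou))
print(best)   # 6 -> the (116, 90) anchor
```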

def transform_targets_for_output(y_true, grid_size, anchor_idxs):
    """ Generate the target values for one YOLO output layer
    :param y_true: [N, boxes, (x1, y1, x2, y2, class, best_anchor)]
    :param grid_size:
    :param anchor_idxs: e.g. [6, 7, 8]
    :return: y_true_out: [N, grid, grid, anchors, [x1, y1, x2, y2, obj, class]]
    """
    # y_true: [N, boxes, (x1, y1, x2, y2, class, best_anchor)]
    print('y_true', y_true)
    N, num_boxes, _ = np.shape(y_true)

    # y_true_out: [N, grid, grid, anchors, [x1, y1, x2, y2, obj, class]]
    y_true_out = np.zeros((N, grid_size, grid_size, np.shape(anchor_idxs)[0], 6), dtype=np.float32)

    anchor_idxs = np.array(anchor_idxs, np.int32)
    for i in np.arange(N):
        for j in np.arange(num_boxes):
            # skip padded (all-zero) entries
            if y_true[i][j][2] == 0:
                continue
            # check whether this box's best anchor (y_true[i][j][5]) is one of this layer's anchors
            anchor_eq = anchor_idxs == y_true[i][j][5]

            # the best anchor belongs to this output layer
            if np.any(anchor_eq):
                box = y_true[i][j][0:4]
                # compute the box center
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2
                anchor_idx = np.array(np.where(anchor_eq)[0], np.int32)
                grid_xy = np.array(box_xy // (1 / grid_size), np.int32)

                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                y_true_out[i, grid_xy[1], grid_xy[0], anchor_idx[0], :] = \
                    [box[0], box[1], box[2], box[3], 1, y_true[i, j, 4]]
    return y_true_out

Output:

anchor_area [0.00083008 0.00545898 0.02046387 0.02223633 0.06881836 0.07457275
 0.2064917  0.21621095 0.5328125 ]
box_area [[[0.00259766 0.00259766 0.00259766 ... 0.00259766 0.00259766 0.00259766]
  [0.02592773 0.02592773 0.02592773 ... 0.02592773 0.02592773 0.02592773]
  [0.08772949 0.08772949 0.08772949 ... 0.08772949 0.08772949 0.08772949]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.102854   0.102854   0.102854   ... 0.102854   0.102854   0.102854  ]
  [0.00481201 0.00481201 0.00481201 ... 0.00481201 0.00481201 0.00481201]
  [0.14018555 0.14018555 0.14018555 ... 0.14018555 0.14018555 0.14018555]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.35424072 0.35424072 0.35424072 ... 0.35424072 0.35424072 0.35424072]
  [0.52195312 0.52195312 0.52195312 ... 0.52195312 0.52195312 0.52195312]
  [0.0018457  0.0018457  0.0018457  ... 0.0018457  0.0018457  0.0018457 ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.42918213 0.42918213 0.42918213 ... 0.42918213 0.42918213 0.42918213]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.24511719 0.24511719 0.24511719 ... 0.24511719 0.24511719 0.24511719]
  [0.12363281 0.12363281 0.12363281 ... 0.12363281 0.12363281 0.12363281]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]]
iou [[[0.31954888 0.47584972 0.12693868 ... 0.01257996 0.01201445 0.00487537]
  [0.03201507 0.21054614 0.66947811 ... 0.12556309 0.11991869 0.04866203]
  [0.00946179 0.06222519 0.1874598  ... 0.25776757 0.40575879 0.1646536 ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00807045 0.05307508 0.19896033 ... 0.49810238 0.33256859 0.19303978]
  [0.17250127 0.50089186 0.23514674 ... 0.02330366 0.0222561  0.00903134]
  [0.00592128 0.03894114 0.14597701 ... 0.56491694 0.62916833 0.26310485]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00234326 0.01541038 0.05776825 ... 0.58291346 0.56153389 0.66485065]
  [0.00159033 0.01045876 0.03920633 ... 0.39561347 0.4142344  0.8811481 ]
  [0.44973546 0.33810375 0.09019327 ... 0.00893839 0.00853659 0.00346408]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00193409 0.01271951 0.04768108 ... 0.48112834 0.42830977 0.69437937]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00338645 0.02227092 0.08348605 ... 0.46233081 0.84243369 0.46004401]
  [0.00671406 0.04415482 0.16552132 ... 0.57129077 0.57181569 0.23203813]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]]
ioushape (5, 100, 9)
y_true [[[ 0.7390625  0.6171875  0.7984375  0.6609375 18.         1.       ]
  [ 0.31875    0.3671875  0.4125     0.64375    1.         2.       ]
  [ 0.         0.7796875  0.5296875  0.9453125 15.         5.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.425      0.3125     0.6609375  0.7484375 18.         4.       ]
  [ 0.2828125  0.134375   0.325      0.2484375 44.         1.       ]
  [ 0.271875   0.         0.6796875  0.34375   70.         7.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.2015625  0.31875    0.809375   0.9015625 18.         8.       ]
  [ 0.0953125  0.08125    0.9140625  0.71875    4.         8.       ]
  [ 0.4703125  0.1140625  0.5140625  0.15625   47.         0.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.175      0.240625   0.7484375  0.9890625 18.         8.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.3125     0.1390625  0.9375     0.53125   18.         7.       ]
  [ 0.146875   0.         0.521875   0.3296875  1.         5.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]]
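The grid-cell assignment in transform_targets_for_output boils down to `box_xy // (1 / grid_size)`: the normalized box center is mapped to the cell that contains it. A small standalone check, using a hypothetical box on a 13×13 grid:

```python
import numpy as np

grid_size = 13
# Hypothetical normalized box (x1, y1, x2, y2).
box = np.array([0.2, 0.3, 0.4, 0.5], np.float32)
box_xy = (box[0:2] + box[2:4]) / 2                     # center (0.3, 0.4)
grid_xy = np.array(box_xy // (1 / grid_size), np.int32)

# The target is written at y_true_out[i, grid_xy[1], grid_xy[0], ...]:
# the row index comes from y, the column index from x.
print(grid_xy)   # [3 5]
```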

Next we build the loss function and the gradient-descent step.

# compute the loss for each of the three output layers
total_xy_loss = total_wh_loss = total_obj_loss = total_class_loss = 0.
for i in range(3):
    # target boxes, objectness and classes
    true_box, true_obj, true_class = np.split(yolo_targets[i], (4, 5), axis=-1)
    # predicted boxes, objectness, classes and raw box offsets
    pred_box, pred_obj, pred_class, pred_box_xywh = yolo_preds[i]

    xy_loss, wh_loss, obj_loss, class_loss = loss(
        pred_box=pred_box,
        pred_box_xywh=pred_box_xywh,
        true_box=true_box,
        pred_obj=pred_obj,
        true_obj=true_obj,
        pred_class=pred_class,
        true_class=true_class,
        anchors=anchors[anchor_masks[i]],
        ignore_thresh=0.5
    )
    # print(i, tf.reduce_mean(xy_loss),  tf.reduce_mean(obj_loss))

    total_xy_loss += tf.reduce_mean(xy_loss)
    total_wh_loss += tf.reduce_mean(wh_loss)
    total_obj_loss += tf.reduce_mean(obj_loss)
    total_class_loss += tf.reduce_mean(class_loss)

total_loss = total_xy_loss + total_wh_loss + total_obj_loss + total_class_loss
grad = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grad, model.trainable_variables))

Let us look at the loss function.

def loss(pred_box, pred_box_xywh, true_box, pred_obj, true_obj, pred_class, true_class, anchors, ignore_thresh,
         balanced_rate=5):
    """
    :param pred_box: [batch_size, grid, grid, anchors, (x1, y1, x2, y2)]
    :param pred_box_xywh: [batch_size, grid, grid, anchors, (tx, ty, tw, th)]
    :param true_box: [batch_size, grid, grid, anchors, (x1, y1, x2, y2)]
    :param pred_obj: [batch_size, grid, grid, anchors, 1]
    :param true_obj: [batch_size, grid, grid, anchors, 1]
    :param pred_class: [batch_size, grid, grid, anchors, num_classes]
    :param true_class: [batch_size, grid, grid, anchors, 1]
    :param anchors: [[w1,h1],[w2,h2],[w3,h3]]
    :param ignore_thresh: IoU threshold separating positive and negative samples
    :param balanced_rate: positive/negative sample balancing ratio
    :return:
    """
    # [batch_size, grid, grid, anchors, 2]
    # predicted center offsets
    pred_xy = pred_box_xywh[..., 0:2]
    # [batch_size, grid, grid, anchors, 2]
    # predicted width/height offsets
    pred_wh = pred_box_xywh[..., 2:4]

    # center of the target box
    true_xy = (true_box[..., 0:2] + true_box[..., 2:4]) / 2
    # width/height of the target box
    true_wh = true_box[..., 2:4] - true_box[..., 0:2]

    # loss scale that up-weights small boxes (small area gives a weight close to 2)
    box_loss_scale = 2 - true_wh[..., 0] * true_wh[..., 1]

    # 3. invert the box equations to obtain the regression targets
    grid_size = tf.shape(true_box)[1]
    grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))
    # [grid_size, grid_size, 1, 2]
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)
    # translate/scale true_box into the regression target space
    # [batch_size, grid, grid, anchors, 2]
    # offset of the target center within its grid cell
    true_xy = true_xy * tf.cast(grid_size, tf.float32) - tf.cast(grid, tf.float32)
    # [batch_size, grid, grid, anchors, 2]

    true_wh = tf.math.log(true_wh / anchors)
    true_wh = tf.where(tf.math.is_inf(true_wh), tf.zeros_like(true_wh), true_wh)

    # 4. calculate all masks
    # [batch_size, grid, grid, anchors]
    obj_mask = tf.squeeze(true_obj, -1)
    # count positive samples and derive the negative-sample budget
    positive_num = tf.cast(tf.reduce_sum(obj_mask), tf.int32) + 1
    negative_num = balanced_rate * positive_num
    # ignore false positive when iou is over threshold
    # [batch_size, grid, grid, anchors, num_gt_box] => [batch_size, grid, grid, anchors, 1]
    # IoU between predicted boxes and ground-truth boxes
    best_iou = tf.map_fn(
        lambda x: tf.reduce_max(broadcast_iou(x[0], tf.boolean_mask(
            x[1], tf.cast(x[2], tf.bool))), axis=-1),
        (pred_box, true_box, obj_mask),
        tf.float32)
    # [batch_size, grid, grid, anchors, 1]
    ignore_mask = tf.cast(best_iou < ignore_thresh, tf.float32)
    # balance negatives against positives by subsampling the ignore mask
    ignore_num = tf.cast(tf.reduce_sum(ignore_mask), tf.int32)
    if ignore_num > negative_num:
        neg_inds = tf.random.shuffle(tf.where(ignore_mask))[:negative_num]
        neg_inds = tf.expand_dims(neg_inds, axis=1)
        ones = tf.ones(tf.shape(neg_inds)[0], tf.float32)
        ones = tf.expand_dims(ones, axis=1)
        # rebuild the mask with only the sampled negatives
        ignore_mask = tf.zeros_like(ignore_mask, tf.float32)
        ignore_mask = tf.tensor_scatter_nd_add(ignore_mask, neg_inds, ones)

    # 5. calculate all losses
    # [batch_size, grid, grid, anchors]
    # center-offset (xy) loss
    xy_loss = obj_mask * box_loss_scale * tf.reduce_sum(tf.square(true_xy - pred_xy), axis=-1)
    # [batch_size, grid, grid, anchors]
    # width/height (wh) loss
    wh_loss = obj_mask * box_loss_scale * tf.reduce_sum(tf.square(true_wh - pred_wh), axis=-1)

    conf_focal = tf.pow(obj_mask - tf.squeeze(pred_obj, -1), 2)
    # objectness (confidence) loss: positives always contribute; negatives only
    # contribute when their best IoU is below ignore_thresh
    obj_loss = losses.binary_crossentropy(true_obj, pred_obj)
    obj_loss = conf_focal * (obj_mask * obj_loss + (1 - obj_mask) * ignore_mask * obj_loss)

    # TODO: use binary_crossentropy instead
    # classification loss (only computed for positive cells)
    class_loss = obj_mask * losses.sparse_categorical_crossentropy(true_class, pred_class)

    # 6. sum over (batch, gridx, gridy, anchors) => (batch, 1)
    xy_loss = tf.reduce_sum(xy_loss, axis=(1, 2, 3))
    wh_loss = tf.reduce_sum(wh_loss, axis=(1, 2, 3))
    obj_loss = tf.reduce_sum(obj_loss, axis=(1, 2, 3))
    class_loss = tf.reduce_sum(class_loss, axis=(1, 2, 3))

    return xy_loss, wh_loss, obj_loss, class_loss
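The `true_wh = log(true_wh / anchors)` inversion above mirrors the YOLO box parameterization bw = pw · exp(tw): a target box exactly the anchor's size yields a regression target of 0, while padded all-zero boxes produce `-inf` from log(0), which the `tf.where` line resets to 0. A plain NumPy sketch of the same arithmetic:

```python
import numpy as np

anchor_wh = np.array([0.25, 0.25], np.float32)

with np.errstate(divide='ignore'):
    # A real box exactly the anchor's size -> target 0.
    t_real = np.log(np.array([0.25, 0.25], np.float32) / anchor_wh)
    # A padded (all-zero) box -> log(0) = -inf, which must be masked out.
    t_pad = np.log(np.array([0.0, 0.0], np.float32) / anchor_wh)

t_pad = np.where(np.isinf(t_pad), np.zeros_like(t_pad), t_pad)
print(t_real)   # [0. 0.]
print(t_pad)    # [0. 0.]
```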

def broadcast_iou(box_1, box_2):
    """ Compute the pairwise IoU between two sets of boxes

    :param box_1: (..., (x1, y1, x2, y2))
    :param box_2: (N, (x1, y1, x2, y2))
    :return: [batch_size, grid, grid, anchors, num_gt_box]
    """

    # broadcast boxes
    box_1 = tf.expand_dims(box_1, -2)
    box_2 = tf.expand_dims(box_2, 0)
    # new_shape: (..., N, (x1, y1, x2, y2))
    new_shape = tf.broadcast_dynamic_shape(tf.shape(box_1), tf.shape(box_2))
    box_1 = tf.broadcast_to(box_1, new_shape)
    box_2 = tf.broadcast_to(box_2, new_shape)

    int_w = tf.maximum(tf.minimum(box_1[..., 2], box_2[..., 2]) -
                       tf.maximum(box_1[..., 0], box_2[..., 0]), 0)
    int_h = tf.maximum(tf.minimum(box_1[..., 3], box_2[..., 3]) -
                       tf.maximum(box_1[..., 1], box_2[..., 1]), 0)
    int_area = int_w * int_h
    box_1_area = (box_1[..., 2] - box_1[..., 0]) * \
                 (box_1[..., 3] - box_1[..., 1])
    box_2_area = (box_2[..., 2] - box_2[..., 0]) * \
                 (box_2[..., 3] - box_2[..., 1])
    return int_area / (box_1_area + box_2_area - int_area)
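broadcast_iou uses the standard corner-form IoU. A quick sanity check of the same arithmetic in plain Python, with two hand-picked boxes whose intersection is 1 and union is 7:

```python
def iou_xyxy(a, b):
    # Corner-form IoU: same arithmetic as broadcast_iou, but for two boxes.
    int_w = max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
    int_h = max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
    inter = int_w * int_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

print(iou_xyxy([0, 0, 2, 2], [1, 1, 3, 3]))   # 1/7 ≈ 0.142857
```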

Finally, record the training logs.

# Scalar
with summary_writer.as_default():
    tf.summary.scalar('loss/xy_loss', total_xy_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/wh_loss', total_wh_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/obj_loss', total_obj_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/class_loss', total_class_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/total_loss', total_loss,
                      step=epoch * train_data.total_batch_size + batch)

# images: log only the first image of each batch
# ground truth
gt_img = gt_imgs[0].copy() * 255
gt_boxes = gt_boxes[0] * image_shape[0]
gt_classes = gt_classes[0]
non_zero_ids = np.where(np.sum(gt_boxes, axis=-1))[0]
for i in non_zero_ids:
    label = gt_classes[i]
    class_name = classes[label]['name']
    xmin, ymin, xmax, ymax = gt_boxes[i]
    gt_img = draw_bounding_box(gt_img, class_name, label, int(xmin), int(ymin), int(xmax),
                               int(ymax))

# prediction
pred_img = gt_imgs[0].copy() * 255
# name the NMS class outputs pred_classes so they do not shadow the
# category dict `classes` that is looked up below
boxes, scores, pred_classes, valid_detection_nums = yolo_nms(yolo_preds, 91)
for i in range(valid_detection_nums[0]):
    if scores[0][i] > 0.5:
        label = pred_classes[0][i].numpy()
        if classes.get(label):
            class_name = classes[label]['name']
            xmin, ymin, xmax, ymax = boxes[0][i] * image_shape[0]
            pred_img = draw_bounding_box(pred_img, class_name, scores[0][i], int(xmin), int(ymin),
                                         int(xmax), int(ymax))

concat_imgs = tf.concat([gt_img[:, :, ::-1], pred_img[:, :, ::-1]], axis=1)
summ_imgs = tf.expand_dims(concat_imgs, 0)
summ_imgs = tf.cast(summ_imgs, dtype=tf.uint8)
with summary_writer.as_default():
    tf.summary.image("imgs/gt,pred,epoch{}".format(epoch), summ_imgs,
                     step=epoch * train_data.total_batch_size + batch)