An Introduction to the YOLO Series

YOLOV1

YOLO (You Only Look Once: Unified, Real-Time Object Detection) was first proposed by Joseph Redmon, Ali Farhadi, et al. in 2015. At CVPR 2017 they presented YOLOV2, and later followed up with YOLOV3. YOLO is a canonical one-stage object detection algorithm.

Compared with the Faster RCNN and SSD families, YOLO commits more fully to the idea of directly regressing both the locations and the classes of the targets to be detected. Its core approach is to take the input image and simultaneously predict the positions and classes of multiple bounding boxes, making it a more thoroughly end-to-end method for detection and recognition. It achieves faster detection than Faster RCNN and SSD, but its overall accuracy is somewhat lower than both.

YOLO performs the whole detection task with a single CNN trained by direct regression, with no additional, specially designed stages. SSD, by contrast, relies on an Anchor mechanism and makes predictions from multiple feature maps; Faster RCNN first obtains candidate bounding boxes through an RPN and then produces final predictions with a second-stage model. YOLO has neither the Anchor mechanism nor the multi-scale design: a single convolutional network directly regresses the positions and classes of multiple bounding boxes. It is trained on whole images, which also helps it distinguish objects from background regions.

The core idea of YOLOV1 is to divide the image into an S*S grid, with each cell responsible for predicting the objects whose ground-truth (GT) centers fall inside it. In the figure above, the cell marked with a red dot contains part of the dog, so that cell is responsible for detecting the dog; likewise, the cell containing the bicycle's center predicts the bicycle.

In practice, each cell predicts B bounding boxes (a hyperparameter, usually 2) together with a confidence for each box (indicating the probability that the cell contains an object), plus C class probabilities (the likelihood of each object category). Each box contributes 5 values (x, y, w, h, c), where c is the confidence, so one cell regresses a vector of length 5*B+C, and the whole image regresses a vector of length (5*B+C)*S*S.
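These vector lengths can be checked with a few lines (a trivial sketch; the helper name `yolov1_output_len` is hypothetical, not from any library):

```python
# Output vector length per cell and for the whole S*S grid in YOLOV1.
def yolov1_output_len(S, B, C):
    per_cell = 5 * B + C          # B boxes of (x, y, w, h, c) plus C class scores
    return S * S * per_cell

print(yolov1_output_len(7, 2, 20))  # 7 * 7 * 30 = 1470
```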

The bounding box values (x, y, w, h) give the object's center as an offset relative to its grid cell plus its width and height, all normalized; in other words, the network predicts relative values. The confidence reflects both whether the box contains an object and, if it does, how accurate the predicted position is. It is defined as

confidence = Pr(Object) * IoU(pred, truth)

where Pr(Object) indicates whether an object's center falls in the cell, and IoU(pred, truth) is the intersection-over-union between the predicted box and the ground truth.

The YOLOV1 network architecture

YOLOV1 obtains the positions and classes of the objects in an image by direct regression. As the figure above shows, the network is purely convolutional: the input image passes through a stack of convolutional layers and finally an FC layer, whose output vector has length S*S*(B*5+C). For YOLOV1, S is typically 7, B is 2, and C is 20 (20 classes). From this regression we obtain, for every cell, the confidences of its bounding boxes and whether it contains an object to detect; if it does, we also obtain the object's offsets, its width and height, and the corresponding class probability distribution.

Decoding the FC output vector then yields

the two figures above. In the first, although each cell predicts several bounding boxes, only the box with the highest IoU is kept as the detection output. In other words, each cell ultimately predicts only one object, which is in fact a weakness of YOLO: an image may contain many small objects, and when several of them fall into the same cell, only one can be predicted, so performance on small targets degrades badly. This is the main flaw of YOLOV1. After merging, screening, and filtering the bounding boxes with NMS, we obtain the final detection result.
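The merge/filter step can be sketched as a minimal greedy NMS (a pure-Python sketch; the (x1, y1, x2, y2) box format and the function names are assumptions for illustration):

```python
# Greedy non-maximum suppression: keep the highest-scoring box, drop every
# box that overlaps it beyond the IoU threshold, and repeat.
def iou(a, b):
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def nms(boxes, scores, iou_thresh=0.5):
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if iou(boxes[best], boxes[i]) < iou_thresh]
    return keep

boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2]: box 1 overlaps box 0 too much and is dropped
```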

YOLO emphasizes small convolutions, namely 1*1 and 3*3 kernels (as in GoogleNet), which reduce both the computation and the model size. The network is faster than VGG16 but slightly less accurate.

  • The YOLOV1 loss function

The loss comprises three parts: a coordinate error, an IoU (confidence) error, and a classification error, matching the three kinds of information each cell predicts. The coordinate error corresponds to the B bounding boxes in S*S*(B*5+C), i.e. the deviation of the predicted box coordinates; the IoU error corresponds to each box's confidence score; and the classification error corresponds to the ground-truth (GT) object class contained in the cell.

These three errors are combined as a weighted sum-squared error, which serves as the final loss for training the network.
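Written out as in the YOLOV1 paper, the sum-squared loss combines the three terms with weights lambda_coord = 5 and lambda_noobj = 0.5:

```latex
\begin{aligned}
L ={}& \lambda_{\text{coord}} \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}}
      \left[(x_i-\hat{x}_i)^2+(y_i-\hat{y}_i)^2\right] \\
  &+ \lambda_{\text{coord}} \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}}
      \left[\left(\sqrt{w_i}-\sqrt{\hat{w}_i}\right)^2+\left(\sqrt{h_i}-\sqrt{\hat{h}_i}\right)^2\right] \\
  &+ \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} (C_i-\hat{C}_i)^2
   + \lambda_{\text{noobj}} \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{noobj}} (C_i-\hat{C}_i)^2 \\
  &+ \sum_{i=0}^{S^2} \mathbb{1}_{i}^{\text{obj}} \sum_{c\in\text{classes}} \left(p_i(c)-\hat{p}_i(c)\right)^2
\end{aligned}
```

Here 1^obj_ij selects the box responsible for an object, and the square roots on w and h partially damp the imbalance between large and small boxes (the limitation discussed below).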

  • Training YOLOV1

Using YOLOV1 in practice involves the following techniques.

First, the detector's parameters are initialized from a pretrained model, here pretrained on the 1000-class ImageNet task. The last few layers differ between the classification and regression tasks: for ImageNet classification the FC output has 1000 units, whereas YOLOV1's final FC output has length S*S*(B*5+C), so the trailing FC layers of the pretrained model are discarded.

Concretely, the first 20 convolutional layers of the pretrained model are used to initialize YOLO for the subsequent detection task, e.g. on the 20-class VOC dataset. ImageNet inputs are 224*224, while YOLOV1 resizes images to 448*448. If only the convolutional layers are reused, this is unproblematic: convolutions are insensitive to feature-map size and depend only on their kernel parameters (size and channel count). FC layers, however, have a parameter count tied to the size of the input image or feature map, so changing the image size would invalidate pretrained FC weights. Because YOLO's pretraining keeps only the first 20 convolutional layers and drops the FC layers, the image size can be changed while the pretrained weights remain usable.

When training the B bounding boxes of a cell, they are all assigned the same ground truth (GT).

  • Limitations of YOLOV1

Its results fall short of SSD and Faster RCNN for several reasons:

  1. The detector uses a fixed input size and no multi-scale features. SSD extracts prior boxes and predictions from 6 different scales, whereas YOLO is a single convolutional network without multi-scale feature maps. After several downsampling layers, the object features it learns are not fine-grained, which hurts detection quality.
  2. YOLOV1 performs poorly on small objects. When one cell contains multiple targets, only one (the highest-IoU box) is predicted and the others are ignored, so missed detections are inevitable.
  3. The IoU term of the loss does not distinguish between the errors of large and small objects: both contribute roughly equally to the training loss. In reality, an IoU error on a small object should weigh more heavily in optimization, so treating them the same degrades localization accuracy. The loss design of YOLOV1 thus has little finesse, and it became an improvement point for later YOLO versions.
  4. Generalization is also weak when an object appears with a new, unusual aspect ratio or in other atypical configurations.
  • YOLOV1 performance

The figure above compares the accuracy of training at different scales against other networks. On the same datasets, YOLOV1's mAP drops considerably. In detection speed, however, at the same scale (448*448) YOLO reaches 45 FPS, far faster than Faster RCNN, much faster than SSD500 (500*500 images), and close to SSD300 (300*300 images). In other words, YOLOV1 on larger images matches the detection speed of SSD on smaller ones.

YOLOV2

To address the problems of YOLOV1, the authors proposed YOLOV2 in 2017 and, on top of it, the YOLO9000 model. The core improvements of YOLOV2 over YOLOV1 are:

  1. It introduces the Anchor box idea, refining the rather coarse direct-regression approach.
  2. The output layer replaces YOLOV1's fully connected (FC) layer with a convolutional layer. An immediate benefit is reduced sensitivity to the input image size: FC parameter counts are tied to the image size, while convolutions are not.
  3. YOLO9000 is trained jointly on two datasets: ImageNet classification and COCO detection. The detection data teach accurate object localization, while the classification data teach class information; this multi-task setup improves the robustness of the final network.
  4. Compared with YOLOV1, YOLOV2 improves substantially not only in the variety of objects it can recognize but also in accuracy, speed, and localization.

YOLOV2 became one of the most representative detection algorithms of its time. The improvements of YOLOV2/YOLO9000 are as follows.

As the figure above shows, the backbone is a DarkNet architecture. In YOLOV1 the authors used a GoogleNet-style backbone, whose performance is better than VGGNet's. DarkNet resembles VGGNet in using small 3*3 kernels and doubling the channel count after every pooling step, and it applies Batch Normalization throughout, which stabilizes training, speeds up convergence, and regularizes the model.

Because convolutions replace the FC layer, the input image size can vary: the network's parameters are independent of the feature-map size, so images of different sizes can be used for multi-scale training. The classification model also uses a high-resolution classifier.

YOLOV1 used features from only one level, so the features it learned were relatively coarse. YOLOV2 adds a skip connection: at prediction time it uses features of different granularities, and fusing them improves detection performance. The final prediction also adopts the Anchor mechanism, a core element of Faster RCNN and SSD that brings a clear performance gain.

  • Batch Normalization
  1. V1 also made heavy use of BN, but its FC localization layers relied on dropout to prevent overfitting.
  2. V2 drops dropout and applies BN throughout the network to regularize the model, making training more stable and convergence faster.
  • High-resolution classifier
  1. V1 pretrained at 224*224 but ran detection on 448*448 images, a mismatch that necessarily introduces a shift in distribution.
  2. V2 fine-tunes the initial classification network directly at 448*448 resolution, keeping the classification and detection models consistent in distribution.
  • Anchor Boxes
  1. Bounding-box offsets are predicted with convolutions instead of FC layers. In V1 the FC output vector has size S*S*(B*5+C); in V2 the convolutional output is an S*S feature map with (B*5+C) channels, achieving the same effect as V1's FC layer.
  2. V2 uses 416*416 inputs rather than 448*448. The authors observed that objects, especially large ones, tend to occupy the center of the image, so it helps to have a single cell at the exact center to predict them. YOLO downsamples by a factor of 32: a 416*416 image yields a 13*13 feature map, whereas a 448*448 image yields 14*14, which has no single central cell. To guarantee a central cell, the downsampled feature map should be 13*13 (odd width and height); multiplying back by the 32x stride gives the 416 input size. With even dimensions, the four middle cells would have to share the prediction of a centered object; the odd-sized output is a simple trick that improves overall efficiency.
  3. With the Anchor mechanism, each cell predicts multiple proposal boxes instead of just B (usually 2). Recall improves markedly while mAP drops slightly; in the authors' view the small accuracy cost is worth the large recall gain, showing that Anchor boxes genuinely lift overall model performance, though the accuracy drop remains something to optimize further. V2 uses max pooling for downsampling.
  4. With Anchors, the number of predicted bounding boxes exceeds 1000: for a 13*13 feature map with 9 boxes per cell, 13*13*9=1521 boxes are predicted, versus 7*7*2=98 before, and the larger pool of predictions improves performance. The authors also ran into two problems when adopting Anchor boxes. First, the Anchor widths and heights are hand-picked priors; although the network adjusts them during training to reach accurate box positions, starting from the most representative prior dimensions makes accurate predictions easier to learn. The authors therefore applied k-means clustering to the ground-truth bounding boxes to find good width/height dimensions automatically. Second, standard k-means measures box similarity with Euclidean distance, so larger boxes generate more error than smaller ones; they instead used the IoU score as the distance metric, making the error independent of box size. The clustering suggested that 5 Anchor boxes is a good trade-off, so 5 box dimensions were chosen for localization. The clusters contained few short, wide boxes and more tall, thin ones, which matches the shape of pedestrians. See the notes on clustering for more on k-means.
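The clustering step can be sketched in pure Python (a simplified sketch: the (w, h) box format and the helper names are assumptions, real implementations run over all training-set ground-truth boxes, and the centroid update here uses the plain mean):

```python
import random

# IoU of two boxes given only (w, h), assuming they share a top-left corner,
# as in YOLOV2's anchor clustering.
def wh_iou(a, b):
    inter = min(a[0], b[0]) * min(a[1], b[1])
    return inter / (a[0] * a[1] + b[0] * b[1] - inter)

# k-means with 1 - IoU as the distance: assign each box to the center with
# the highest IoU, then recompute each center as its cluster's mean (w, h).
def kmeans_anchors(boxes, k, iters=100, seed=0):
    random.seed(seed)
    centers = random.sample(boxes, k)
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for box in boxes:
            j = max(range(k), key=lambda c: wh_iou(box, centers[c]))
            clusters[j].append(box)
        centers = [
            (sum(b[0] for b in cl) / len(cl), sum(b[1] for b in cl) / len(cl))
            if cl else centers[j]
            for j, cl in enumerate(clusters)
        ]
    return centers

boxes = [(10, 10), (12, 12), (100, 100), (110, 110)]
print(sorted(kmeans_anchors(boxes, 2)))  # two centers, near (11, 11) and (105, 105)
```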
  • Fine-grained features

The earlier V1 design is a straight top-down network with no multi-scale features. V2 adds a pass-through layer that connects a shallow feature map (26*26) to a deep one (13*13). Instead of pooling, the 26*26*512 feature map is restacked directly into 13*13*2048 and then joined with the deep feature map, adding fine-grained features. Fusing coarse and fine granularities yields about a 1% performance gain. The idea is similar to the identity mapping in ResNet.
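The restacking can be sketched in numpy (assuming HWC layout; Darknet's reorg layer differs in element ordering but produces the same shape):

```python
import numpy as np

# Pass-through (reorg) layer: move each 2x2 spatial block into channels,
# turning an H x W x C map into (H/2) x (W/2) x 4C.
def space_to_depth(x, block=2):
    h, w, c = x.shape
    x = x.reshape(h // block, block, w // block, block, c)
    x = x.transpose(0, 2, 1, 3, 4)            # group by output cell first
    return x.reshape(h // block, w // block, block * block * c)

fm = np.zeros((26, 26, 512), dtype=np.float32)
print(space_to_depth(fm).shape)  # (13, 13, 2048)
```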

  • Multi-Scale Training

Multi-scale training fine-tunes the network's input size every few iterations, drawing from sizes {320, 352, ..., 608}. The same parameters work across input scales because the V2 architecture contains no FC layers whose size depends on the feature map; the network is built entirely from stacked convolutions, so its parameter count is independent of feature-map size. Varying the image size therefore makes the network robust to scale changes, and the same network can run detection at different resolutions. On small images V2 runs faster, trading between speed and accuracy: at a 288*288 input its FPS can reach 90 with mAP on par with Faster RCNN. V2 is thus well suited to low-end GPUs, high-frame-rate video, and multi-stream video scenarios, i.e. low-power and video-processing settings, where it delivers higher real-time performance while keeping accuracy at the level of other deep-learning detectors.
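Since every training scale must remain compatible with the network's 32x downsampling, the size set {320, 352, ..., 608} is simply the multiples of 32 in that range:

```python
# Multi-scale training sizes: steps of the 32x network stride from 320 to 608.
stride = 32
scales = [320 + stride * i for i in range(10)]
print(scales)  # [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
```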

  • Darknet-19

V1 used a GoogleNet-style backbone; for V2 the authors designed a new feature extractor, Darknet, drawing on earlier successful designs. Its structure resembles VGGNet, stacking many 3*3 kernels and doubling the channel count after each Max Pooling step. Borrowing from Network in Network, it uses global average pooling and places 1*1 kernels between the 3*3 kernels to compress features, and it applies batch normalization throughout, which both speeds up training and stabilizes it. Darknet comprises 19 convolutional layers and 5 pooling layers and requires 5.58 billion operations, somewhat less computation than VGGNet, while reaching 72.9% top-1 accuracy on ImageNet classification. When using YOLOV2, one can also swap in more advanced backbones such as ResNet or DenseNet, or lighter ones such as MobileNet; which backbone works best for YOLOV2 can be determined by experiment.

When predicting Anchor boxes in YOLOV2, each bounding box predicts 4 coordinate values, 1 confidence value, and a C-way class probability distribution. This differs from V1, where the class probabilities belong to a cell: each cell corresponds to B bounding boxes and regresses a (5*B+C) vector. In V2 the class distribution belongs to each bounding box, consistent with the Anchor mechanism: each box predicts a vector of length (5+C), so with B anchor boxes per cell the output is B*(5+C) values. The class prediction is thus per-box rather than per-cell, a direct consequence of the Anchor mechanism; YOLOV1, in contrast, predicts only one object class per cell.
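The per-cell output sizes of the two schemes can be compared directly (a trivial sketch; the helper names are hypothetical):

```python
# Per-cell output lengths: V1 ties the class distribution to the cell,
# V2 ties it to each anchor box.
def v1_cell_len(B, C):
    return 5 * B + C          # B boxes share one C-way class distribution

def v2_cell_len(B, C):
    return B * (5 + C)        # each anchor box carries its own class distribution

print(v1_cell_len(2, 20), v2_cell_len(5, 20))  # 30 125
```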

  • YOLOV2 performance

Looking at YOLOV2's improvements over YOLOV1 one by one, nearly every addition raises performance. The one dip, as the figure above shows, comes from adding Anchor boxes, where mAP drops slightly from 69.5 to 69.2 but recall rises substantially. After the later additions of skip connections and multi-scale training, YOLOV2 improves greatly over V1 overall, from 63.4 to 78.6.

Compared with SSD and Faster RCNN, YOLOV2 reaches better detection accuracy at a faster detection speed, which made it the most advanced deep-learning detector of its time.

The figure above plots mAP against FPS overall: YOLOV2 achieves a better trade-off, maintaining good detection accuracy while keeping detection fast.

YOLO9000

YOLO9000, built on YOLOV2, can detect more than 9000 object categories; its main contribution is a joint training strategy for classification and detection.

This is made possible by the WordTree structure, which mixes data from detection and recognition datasets so that detection and classification can be trained jointly, on ImageNet and COCO respectively. Classification labels are finer-grained: where a detection dataset distinguishes only coarse concepts such as cat versus dog, a classification dataset may subdivide dog into finer labels such as Husky or Golden Retriever. Naively merging the two tasks would leave labels like "dog" and "Husky" coexisting with no relation between them. WordTree instead builds the granularity relations between these labels, fusing the classification and detection datasets: on detection data the model must both localize objects and judge their (coarser) classes, while on classification data it classifies at a finer granularity. WordTree expresses the hierarchy and containment relations among labels, represented as a graph in the spirit of WordNet.

During training, an image's label is expanded to all of its ancestors: a dog is also a mammal, a canine, perhaps a domestic animal, so one image carries multiple labels, and labels need not be mutually exclusive. Plain ImageNet classification uses one large SoftMax over all classes; WordTree instead applies a SoftMax over the co-hyponyms under each concept. The benefit is that performance degrades gracefully on unfamiliar new objects: shown a photo of a dog of an unknown breed, the model still predicts "dog" with high confidence, while the breed labels beneath it, such as Husky or Golden Retriever, all receive low confidence. Mixing the COCO detection set with ImageNet classification data in this way, and training detection and classification jointly on the combined data, yields YOLO9000, a stronger classifier and detector that handles 9000 object categories while remaining highly real-time. YOLO9000 is thus called the "stronger" version of YOLOV2.

In the figure above, ImageNet classification uses one large SoftMax over every class, whereas WordTree takes the relations between labels into account and computes the classification loss from SoftMaxes over the co-hyponyms under each concept. With this joint training strategy, YOLO9000 can rapidly detect more than 9000 categories of objects, with an overall mAP of 19.7%.

YOLOV3

Relative to V1 and V2, YOLOV3 focuses on balancing speed and accuracy; it folds in more recent techniques and in particular tackles the small-object detection problem.

  • YOLOV3 improvements:

1. The backbone is optimized first, adopting a ResNet-like structure to extract more and better feature representations.

As the figure above shows, the ResNet-style backbone yields better detection, though the deeper network also slows detection down; this is again a trade-off between speed and accuracy.

2. Multi-scale prediction with an FPN-like structure improves detection accuracy.

In the lower right of the figure above, V3 extracts features from feature maps at different scales as input to the YOLO detection heads. Anchors are again designed by clustering, yielding 9 clusters (centroids) that are split evenly across 3 scales, with 3 bounding boxes predicted per scale. At each scale, a few convolutional layers further process the features before the box information is output. At scale 1, boxes are output directly after convolution. At scale 2, the convolutional features from scale 1 are upsampled and combined with the scale-2 feature map before the subsequent box outputs; this feature map is twice the size of scale 1's. Scale 3 likewise doubles scale 2: its input is the upsampled scale-2 feature map joined with the corresponding earlier feature map, followed by convolutions that output the final boxes. The whole arrangement resembles an FPN.
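The scale fusion above can be sketched in numpy (shapes and channel counts here are assumptions for illustration; the real network does this with convolutional layers, and V3's route layer joins the maps along the channel axis):

```python
import numpy as np

# Nearest-neighbor 2x upsampling of an HWC feature map.
def upsample2x(x):
    return x.repeat(2, axis=0).repeat(2, axis=1)

coarse = np.zeros((13, 13, 256), dtype=np.float32)   # head features at scale 1
mid = np.zeros((26, 26, 512), dtype=np.float32)      # earlier backbone features

# Upsample the coarse map to 26x26 and join it with the mid-level map
# along the channel axis, giving the input for the scale-2 head.
fused = np.concatenate([upsample2x(coarse), mid], axis=-1)
print(fused.shape)  # (26, 26, 768)
```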

3. A better classifier, using binary cross-entropy loss, handles the classification task.

The main reason is that SoftMax assigns each bounding box exactly one class, the one with the highest score: it outputs a probability distribution and takes the peak as the box's class. When targets overlap and carry multiple labels, SoftMax is unsuited to such multi-label classification. In fact, SoftMax can be replaced by multiple independent logistic classifiers without any loss of accuracy.
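The difference can be illustrated in a few lines (the label names and scores are made up for illustration): independent sigmoid scores let one box carry several labels at once, unlike SoftMax, which forces a single winner.

```python
import math

# Logistic (sigmoid) activation: each class is scored independently,
# so several classes can be positive for the same box.
def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

logits = {"person": 2.0, "woman": 1.5, "dog": -3.0}
labels = {k: sigmoid(v) > 0.5 for k, v in logits.items()}
print(labels)  # person and woman both positive; dog negative
```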

  • YOLOV3 performance

The figure above shows that V3 achieves better results than the YOLOV2 architecture. Since the V3 entry in the figure uses Darknet, its performance is somewhat lower than variants built on ResNet.

Across different backbones for YOLOV3 itself, ResNet-152 gives the best overall performance.

On the COCO dataset, the comparison likewise shows YOLOV3 holding a solid performance advantage over other recognition architectures. Its overall detection speed is somewhat lower than before, but still faster than the other detection algorithms.

The reference implementation of YOLOV3 is built on the Darknet framework, written in C and CUDA. It uses GPU memory efficiently, has few third-party dependencies, and ports easily across platforms, running well on Windows and on embedded devices. Darknet is also a good framework for implementing deep networks in general.

Now let us look at the code structure of YOLOV3, again with Darknet as the V3 backbone.

import tensorflow as tf
from tensorflow.keras import layers, regularizers, models


class DarkNet:
    def __init__(self):
        pass

    def _darknet_conv(self, x, filters, size, strides=1, batch_norm=True):
        if strides == 1:
            padding = 'same'
        else:
            # pad one extra row (top) and column (left) of zeros for the strided conv
            x = layers.ZeroPadding2D(((1, 0), (1, 0)))(x)  # top left half-padding
            padding = 'valid'
        x = layers.Conv2D(filters, (size, size),
                          strides=strides,
                          padding=padding,
                          use_bias=not batch_norm,
                          kernel_regularizer=regularizers.l2(0.0005))(x)
        if batch_norm:
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(alpha=0.1)(x)
        return x

    def _darknet_residual(self, x, filters):
        prev = x
        x = self._darknet_conv(x, filters // 2, 1)
        x = self._darknet_conv(x, filters, 3)
        x = layers.Add()([prev, x])
        return x

    def _darknet_block(self, x, filters, blocks):
        x = self._darknet_conv(x, filters, 3, strides=2)
        for _ in range(blocks):
            x = self._darknet_residual(x, filters)
        return x

    def build_darknet(self, x, name=None):
        # x = inputs = tf.keras.layers.Input([None, None, 3])
        x = self._darknet_conv(x, 32, 3)
        # 1/2
        x = self._darknet_block(x, 64, 1)
        # 1/4
        x = self._darknet_block(x, 128, 2)
        # 1/8
        x = x1 = self._darknet_block(x, 256, 8)
        # 1/16
        x = x2 = self._darknet_block(x, 512, 8)
        # 1/32
        x3 = self._darknet_block(x, 1024, 4)
        # return tf.keras.Model(inputs, (x_36, x_61, x), name=name)
        return x1, x2, x3

    def build_darknet_tiny(self, x, name=None):
        # x = inputs = tf.keras.layers.Input([None, None, 3])
        x = self._darknet_conv(x, 16, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 32, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 64, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 128, 3)
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = x_8 = self._darknet_conv(x, 256, 3)  # skip connection
        x = layers.MaxPool2D(2, 2, 'same')(x)
        x = self._darknet_conv(x, 512, 3)
        x = layers.MaxPool2D(2, 1, 'same')(x)
        x = self._darknet_conv(x, 1024, 3)
        # return tf.keras.Model(inputs, (x_8, x), name=name)
        return x_8, x

if __name__ == '__main__':
    # import os
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    darknet = DarkNet()
    x = layers.Input(shape=(500, 600, 3))
    darknet_model = darknet.build_darknet(x)
    model = models.Model(x, darknet_model)
    model.summary()  # summary() prints the table itself and returns None

Output

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 500, 600, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 500, 600, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 500, 600, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 500, 600, 32) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 501, 601, 32) 0           leaky_re_lu[0][0]                
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 250, 300, 64) 18432       zero_padding2d[0][0]             
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 250, 300, 64) 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 250, 300, 64) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 250, 300, 32) 2048        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 250, 300, 32) 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 250, 300, 32) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 250, 300, 64) 18432       leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 250, 300, 64) 256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 250, 300, 64) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
add (Add)                       (None, 250, 300, 64) 0           leaky_re_lu_1[0][0]              
                                                                 leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 251, 301, 64) 0           add[0][0]                        
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 125, 150, 128 73728       zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 125, 150, 128 512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 125, 150, 128 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 125, 150, 64) 8192        leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 125, 150, 64) 256         conv2d_5[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 125, 150, 64) 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 125, 150, 128 73728       leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 125, 150, 128 512         conv2d_6[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 125, 150, 128 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
add_1 (Add)                     (None, 125, 150, 128 0           leaky_re_lu_4[0][0]              
                                                                 leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 125, 150, 64) 8192        add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 125, 150, 64) 256         conv2d_7[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 125, 150, 64) 0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 125, 150, 128 73728       leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 125, 150, 128 512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 125, 150, 128 0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 125, 150, 128 0           add_1[0][0]                      
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 126, 151, 128 0           add_2[0][0]                      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 62, 75, 256)  294912      zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 62, 75, 256)  1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 62, 75, 256)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 62, 75, 128)  32768       leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 62, 75, 128)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 62, 75, 256)  1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
add_3 (Add)                     (None, 62, 75, 256)  0           leaky_re_lu_9[0][0]              
                                                                 leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 62, 75, 128)  32768       add_3[0][0]                      
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 62, 75, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_12[0][0]             
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 62, 75, 256)  1024        conv2d_13[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
add_4 (Add)                     (None, 62, 75, 256)  0           add_3[0][0]                      
                                                                 leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 62, 75, 128)  32768       add_4[0][0]                      
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 62, 75, 128)  512         conv2d_14[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 62, 75, 256)  1024        conv2d_15[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
add_5 (Add)                     (None, 62, 75, 256)  0           add_4[0][0]                      
                                                                 leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 62, 75, 128)  32768       add_5[0][0]                      
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 62, 75, 128)  512         conv2d_16[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_16[0][0]             
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 62, 75, 256)  1024        conv2d_17[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
add_6 (Add)                     (None, 62, 75, 256)  0           add_5[0][0]                      
                                                                 leaky_re_lu_17[0][0]             
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 62, 75, 128)  32768       add_6[0][0]                      
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 62, 75, 128)  512         conv2d_18[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_18 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 62, 75, 256)  1024        conv2d_19[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_19 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
add_7 (Add)                     (None, 62, 75, 256)  0           add_6[0][0]                      
                                                                 leaky_re_lu_19[0][0]             
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 62, 75, 128)  32768       add_7[0][0]                      
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 62, 75, 128)  512         conv2d_20[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_20 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_20[0][0]     
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_20[0][0]             
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 62, 75, 256)  1024        conv2d_21[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_21 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_21[0][0]     
__________________________________________________________________________________________________
add_8 (Add)                     (None, 62, 75, 256)  0           add_7[0][0]                      
                                                                 leaky_re_lu_21[0][0]             
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 62, 75, 128)  32768       add_8[0][0]                      
__________________________________________________________________________________________________
batch_normalization_22 (BatchNo (None, 62, 75, 128)  512         conv2d_22[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_22[0][0]     
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
batch_normalization_23 (BatchNo (None, 62, 75, 256)  1024        conv2d_23[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_23[0][0]     
__________________________________________________________________________________________________
add_9 (Add)                     (None, 62, 75, 256)  0           add_8[0][0]                      
                                                                 leaky_re_lu_23[0][0]             
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 62, 75, 128)  32768       add_9[0][0]                      
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, 62, 75, 128)  512         conv2d_24[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_24 (LeakyReLU)      (None, 62, 75, 128)  0           batch_normalization_24[0][0]     
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 62, 75, 256)  294912      leaky_re_lu_24[0][0]             
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 62, 75, 256)  1024        conv2d_25[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_25 (LeakyReLU)      (None, 62, 75, 256)  0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
add_10 (Add)                    (None, 62, 75, 256)  0           add_9[0][0]                      
                                                                 leaky_re_lu_25[0][0]             
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 63, 76, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 31, 37, 512)  1179648     zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 31, 37, 512)  2048        conv2d_26[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_26 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_26[0][0]     
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 31, 37, 256)  131072      leaky_re_lu_26[0][0]             
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 31, 37, 256)  1024        conv2d_27[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_27 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_27[0][0]     
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_27[0][0]             
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 31, 37, 512)  2048        conv2d_28[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_28 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
add_11 (Add)                    (None, 31, 37, 512)  0           leaky_re_lu_26[0][0]             
                                                                 leaky_re_lu_28[0][0]             
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 31, 37, 256)  131072      add_11[0][0]                     
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 31, 37, 256)  1024        conv2d_29[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_29 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_29[0][0]     
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_29[0][0]             
__________________________________________________________________________________________________
batch_normalization_30 (BatchNo (None, 31, 37, 512)  2048        conv2d_30[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_30 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_30[0][0]     
__________________________________________________________________________________________________
add_12 (Add)                    (None, 31, 37, 512)  0           add_11[0][0]                     
                                                                 leaky_re_lu_30[0][0]             
__________________________________________________________________________________________________
conv2d_31 (Conv2D)              (None, 31, 37, 256)  131072      add_12[0][0]                     
__________________________________________________________________________________________________
batch_normalization_31 (BatchNo (None, 31, 37, 256)  1024        conv2d_31[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_31 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_31[0][0]     
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_31[0][0]             
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 31, 37, 512)  2048        conv2d_32[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_32[0][0]     
__________________________________________________________________________________________________
add_13 (Add)                    (None, 31, 37, 512)  0           add_12[0][0]                     
                                                                 leaky_re_lu_32[0][0]             
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 31, 37, 256)  131072      add_13[0][0]                     
__________________________________________________________________________________________________
batch_normalization_33 (BatchNo (None, 31, 37, 256)  1024        conv2d_33[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_33 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_33[0][0]     
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_33[0][0]             
__________________________________________________________________________________________________
batch_normalization_34 (BatchNo (None, 31, 37, 512)  2048        conv2d_34[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_34 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_34[0][0]     
__________________________________________________________________________________________________
add_14 (Add)                    (None, 31, 37, 512)  0           add_13[0][0]                     
                                                                 leaky_re_lu_34[0][0]             
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 31, 37, 256)  131072      add_14[0][0]                     
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 31, 37, 256)  1024        conv2d_35[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_35 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_35[0][0]     
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_35[0][0]             
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 31, 37, 512)  2048        conv2d_36[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_36 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_36[0][0]     
__________________________________________________________________________________________________
add_15 (Add)                    (None, 31, 37, 512)  0           add_14[0][0]                     
                                                                 leaky_re_lu_36[0][0]             
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 31, 37, 256)  131072      add_15[0][0]                     
__________________________________________________________________________________________________
batch_normalization_37 (BatchNo (None, 31, 37, 256)  1024        conv2d_37[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_37 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_37[0][0]     
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_37[0][0]             
__________________________________________________________________________________________________
batch_normalization_38 (BatchNo (None, 31, 37, 512)  2048        conv2d_38[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_38 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_38[0][0]     
__________________________________________________________________________________________________
add_16 (Add)                    (None, 31, 37, 512)  0           add_15[0][0]                     
                                                                 leaky_re_lu_38[0][0]             
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 31, 37, 256)  131072      add_16[0][0]                     
__________________________________________________________________________________________________
batch_normalization_39 (BatchNo (None, 31, 37, 256)  1024        conv2d_39[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_39 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_39[0][0]     
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_39[0][0]             
__________________________________________________________________________________________________
batch_normalization_40 (BatchNo (None, 31, 37, 512)  2048        conv2d_40[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_40 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_40[0][0]     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 31, 37, 512)  0           add_16[0][0]                     
                                                                 leaky_re_lu_40[0][0]             
__________________________________________________________________________________________________
conv2d_41 (Conv2D)              (None, 31, 37, 256)  131072      add_17[0][0]                     
__________________________________________________________________________________________________
batch_normalization_41 (BatchNo (None, 31, 37, 256)  1024        conv2d_41[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_41 (LeakyReLU)      (None, 31, 37, 256)  0           batch_normalization_41[0][0]     
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 31, 37, 512)  1179648     leaky_re_lu_41[0][0]             
__________________________________________________________________________________________________
batch_normalization_42 (BatchNo (None, 31, 37, 512)  2048        conv2d_42[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_42 (LeakyReLU)      (None, 31, 37, 512)  0           batch_normalization_42[0][0]     
__________________________________________________________________________________________________
add_18 (Add)                    (None, 31, 37, 512)  0           add_17[0][0]                     
                                                                 leaky_re_lu_42[0][0]             
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 32, 38, 512)  0           add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_43 (Conv2D)              (None, 15, 18, 1024) 4718592     zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
batch_normalization_43 (BatchNo (None, 15, 18, 1024) 4096        conv2d_43[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_43 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_43[0][0]     
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 15, 18, 512)  524288      leaky_re_lu_43[0][0]             
__________________________________________________________________________________________________
batch_normalization_44 (BatchNo (None, 15, 18, 512)  2048        conv2d_44[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_44 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_44[0][0]     
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_44[0][0]             
__________________________________________________________________________________________________
batch_normalization_45 (BatchNo (None, 15, 18, 1024) 4096        conv2d_45[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_45 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_45[0][0]     
__________________________________________________________________________________________________
add_19 (Add)                    (None, 15, 18, 1024) 0           leaky_re_lu_43[0][0]             
                                                                 leaky_re_lu_45[0][0]             
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 15, 18, 512)  524288      add_19[0][0]                     
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 15, 18, 512)  2048        conv2d_46[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_46[0][0]     
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_46[0][0]             
__________________________________________________________________________________________________
batch_normalization_47 (BatchNo (None, 15, 18, 1024) 4096        conv2d_47[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_47 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_47[0][0]     
__________________________________________________________________________________________________
add_20 (Add)                    (None, 15, 18, 1024) 0           add_19[0][0]                     
                                                                 leaky_re_lu_47[0][0]             
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 15, 18, 512)  524288      add_20[0][0]                     
__________________________________________________________________________________________________
batch_normalization_48 (BatchNo (None, 15, 18, 512)  2048        conv2d_48[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_48 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_48[0][0]     
__________________________________________________________________________________________________
conv2d_49 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_48[0][0]             
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 15, 18, 1024) 4096        conv2d_49[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_49 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_49[0][0]     
__________________________________________________________________________________________________
add_21 (Add)                    (None, 15, 18, 1024) 0           add_20[0][0]                     
                                                                 leaky_re_lu_49[0][0]             
__________________________________________________________________________________________________
conv2d_50 (Conv2D)              (None, 15, 18, 512)  524288      add_21[0][0]                     
__________________________________________________________________________________________________
batch_normalization_50 (BatchNo (None, 15, 18, 512)  2048        conv2d_50[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_50 (LeakyReLU)      (None, 15, 18, 512)  0           batch_normalization_50[0][0]     
__________________________________________________________________________________________________
conv2d_51 (Conv2D)              (None, 15, 18, 1024) 4718592     leaky_re_lu_50[0][0]             
__________________________________________________________________________________________________
batch_normalization_51 (BatchNo (None, 15, 18, 1024) 4096        conv2d_51[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_51 (LeakyReLU)      (None, 15, 18, 1024) 0           batch_normalization_51[0][0]     
__________________________________________________________________________________________________
add_22 (Add)                    (None, 15, 18, 1024) 0           add_21[0][0]                     
                                                                 leaky_re_lu_51[0][0]             
==================================================================================================
Total params: 40,620,640
Trainable params: 40,584,928
Non-trainable params: 35,712
__________________________________________________________________________________________________
None
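The per-layer parameter counts in the summary above can be verified by hand: DarkNet's convolutions carry no bias term (each is immediately followed by batch normalization), so a Conv2D layer has kh*kw*cin*cout weights, and each BatchNormalization layer has 4 parameters per channel (gamma, beta, and the two moving statistics). A quick sketch checking a few rows of the table:

```python
# Parameter-count formulas for the layers in the summary above.
# Conv2D (no bias):        kh * kw * in_channels * out_channels
# BatchNormalization:      4 * channels (gamma, beta, moving mean, moving var)

def conv_params(kh, kw, cin, cout):
    return kh * kw * cin * cout

def bn_params(channels):
    return 4 * channels

print(conv_params(3, 3, 128, 256))   # conv2d_21  -> 294912
print(conv_params(1, 1, 256, 128))   # conv2d_22  -> 32768
print(conv_params(3, 3, 512, 1024))  # conv2d_43  -> 4718592
print(bn_params(1024))               # batch_normalization_43 -> 4096
```

These match the Param # column, which is also why the non-trainable count (35,712) is exactly half of the batch-norm parameters: the moving mean and variance are not trained.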

 Now let's feed a real image through the DarkNet network and look at its output.

import tensorflow as tf
from tensorflow.keras import preprocessing
import numpy as np
from skimage.transform import resize
import cv2
from darknet import DarkNet

if __name__ == '__main__':

    # Load the test image as RGB and print its original shape
    img = preprocessing.image.load_img("/Users/admin/Documents/6565.jpeg", color_mode='rgb')
    print(np.asarray(img).shape)
    img = preprocessing.image.img_to_array(img, dtype=np.uint8)
    # Resize to the 640*640 network input size
    img = resize(img, (640, 640, 3))
    # OpenCV displays images in BGR order, so swap the channels for imshow
    r, g, b = cv2.split(img)
    img_new = cv2.merge((b, g, r))
    cv2.imshow('img', img_new)
    cv2.waitKey()
    # Add the batch dimension and cast to float32 before feeding the network
    img = tf.reshape(img, shape=(1, 640, 640, 3))
    img = tf.cast(img, dtype=tf.float32)
    darknet = DarkNet()
    x1, x2, x3 = darknet.build_darknet(img, "darknet")
    print(x1)
    print(x2)
    print(x3)

The output:

(500, 600, 3)
tf.Tensor(
[[[[ 0.01300546  0.04580329  0.02618899 ...  0.0024302   0.00778028
     0.02046073]
   [ 0.01523247  0.0433247   0.03101609 ...  0.00859393  0.01517401
     0.04126479]
   [ 0.01569667  0.04655103  0.02909921 ...  0.01153814  0.0103563
     0.04232045]
   ...
   [ 0.01461979  0.05112193  0.03184253 ...  0.01749601  0.01063334
     0.04914174]
   [ 0.01122806  0.04039535  0.03295705 ...  0.01560631  0.00924575
     0.047635  ]
   [ 0.01934911  0.01768233  0.02554618 ...  0.02241347  0.01196023
     0.04969869]]

  [[ 0.01979208  0.04440274  0.03590046 ... -0.00037868  0.02252091
     0.01689843]
   [ 0.0281611   0.04216104  0.03869232 ...  0.00249345  0.0317611
     0.05443078]
   [ 0.02877221  0.0423179   0.04370406 ...  0.01127182  0.02932161
     0.05519786]
   ...
   [ 0.02331033  0.04598416  0.04644613 ...  0.0131951   0.03045727
     0.06477398]
   [ 0.0180282   0.03877962  0.04242241 ...  0.01323145  0.02318294
     0.06416938]
   [ 0.03270163  0.02084929  0.03383677 ...  0.02968623  0.01979706
     0.07166156]]

  [[ 0.02432447  0.050202    0.03617659 ... -0.00233297  0.03178607
     0.01721654]
   [ 0.02453477  0.04319714  0.04343178 ...  0.00594528  0.03310958
     0.06119607]
   [ 0.02629931  0.05030533  0.04756507 ...  0.01565211  0.02879084
     0.06568647]
   ...
   [ 0.02312282  0.04931618  0.05300389 ...  0.02371687  0.02887772
     0.07611271]
   [ 0.01672611  0.04033546  0.05156515 ...  0.02270166  0.02590587
     0.07241585]
   [ 0.0344753   0.01985516  0.0468677  ...  0.03887149  0.02358301
     0.07878023]]

  ...

  [[ 0.01810458  0.03465767  0.0359303  ... -0.00219411  0.02293537
     0.01476098]
   [ 0.0134875   0.02783474  0.0325518  ...  0.00836905  0.01697291
     0.05196021]
   [ 0.01130611  0.03472897  0.04398032 ...  0.01915504  0.01711655
     0.06013133]
   ...
   [ 0.01296154  0.03170425  0.06206837 ...  0.03199976  0.02187147
     0.08614711]
   [ 0.01051291  0.026424    0.05665638 ...  0.02734326  0.01725644
     0.08029547]
   [ 0.02285716  0.01315682  0.05187852 ...  0.04164089  0.02339076
     0.07870194]]

  [[ 0.01582998  0.03424459  0.03260812 ... -0.00325415  0.02369688
     0.01293045]
   [ 0.01067719  0.0237003   0.03294854 ...  0.00414321  0.01935142
     0.04889458]
   [ 0.00755305  0.02681288  0.04651605 ...  0.01095157  0.02234309
     0.05407221]
   ...
   [ 0.00872021  0.02522171  0.06343129 ...  0.02143701  0.02625748
     0.07797993]
   [ 0.00934438  0.02784847  0.05337022 ...  0.01815011  0.0163433
     0.07494247]
   [ 0.02527687  0.01019516  0.0487967  ...  0.03363573  0.02205626
     0.07579175]]

  [[ 0.01823313  0.01269497  0.02201747 ...  0.00057489  0.01901229
     0.01401678]
   [ 0.01871967  0.01256928  0.02453714 ...  0.00147328  0.02255515
     0.04461277]
   [ 0.01581922  0.01575183  0.03556562 ...  0.00610318  0.02761898
     0.04809618]
   ...
   [ 0.02686368  0.01529721  0.04575747 ...  0.01040503  0.03783914
     0.06418921]
   [ 0.02419975  0.01501905  0.04164971 ...  0.01177746  0.03296132
     0.05938878]
   [ 0.02779927  0.01238328  0.04344043 ...  0.02279663  0.02802737
     0.05404975]]]], shape=(1, 80, 80, 256), dtype=float32)
tf.Tensor(
[[[[ 5.57744503e-02  4.37109694e-02  1.97642922e-01 ...  1.28797188e-01
     3.13888304e-02  1.42925873e-01]
   [ 4.39740494e-02  2.06579026e-02  2.41064608e-01 ...  2.13760287e-01
     7.33558387e-02  2.18340456e-01]
   [ 2.09208392e-02  3.09265479e-02  2.46887892e-01 ...  2.80931830e-01
     8.50827098e-02  2.63412178e-01]
   ...
   [ 2.72087976e-02  5.25276959e-02  2.19686821e-01 ...  3.05240124e-01
     8.77861679e-02  2.72402823e-01]
   [ 2.65834890e-02  4.31933478e-02  1.73503757e-01 ...  2.72548974e-01
     8.33168477e-02  2.47264355e-01]
   [-1.60129666e-02  9.67956055e-03  1.04837202e-01 ...  2.23080724e-01
     4.11004983e-02  1.79972097e-01]]

  [[ 9.13270712e-02  4.25641313e-02  2.92546570e-01 ...  1.85710311e-01
     5.48139662e-02  2.48421669e-01]
   [ 6.77959770e-02  2.06972919e-02  4.27425057e-01 ...  3.10112268e-01
     1.27032176e-01  3.22345883e-01]
   [ 2.10890956e-02  2.81567629e-02  4.73780841e-01 ...  3.85221779e-01
     1.43670052e-01  3.64688814e-01]
   ...
   [ 5.77030051e-03  3.50399017e-02  4.49282169e-01 ...  3.97970021e-01
     1.31809086e-01  3.32577288e-01]
   [ 1.65432375e-02  1.57514010e-02  4.09302562e-01 ...  3.83579433e-01
     1.19915843e-01  3.17814291e-01]
   [-6.36246568e-03  5.90018928e-03  2.24725455e-01 ...  3.30318391e-01
     6.04600236e-02  1.92407310e-01]]

  [[ 9.07450020e-02  5.02355471e-02  3.22316527e-01 ...  2.17806578e-01
     7.86794573e-02  2.86677599e-01]
   [ 5.24090864e-02  2.40034275e-02  4.78858292e-01 ...  3.34466815e-01
     1.59046844e-01  3.56041044e-01]
   [-7.65481964e-04  5.54743968e-02  5.47300458e-01 ...  4.32219326e-01
     1.80383742e-01  4.10912067e-01]
   ...
   [-1.89762097e-02  6.06380999e-02  4.94807541e-01 ...  4.25486624e-01
     1.76484093e-01  3.82894129e-01]
   [-1.15267523e-02  4.02785651e-02  4.51946974e-01 ...  4.17101383e-01
     1.62846759e-01  3.76177728e-01]
   [-2.17601825e-02  1.85616724e-02  2.38932043e-01 ...  3.90599698e-01
     5.71805388e-02  2.44674429e-01]]

  ...

  [[ 6.35816827e-02  3.37417796e-02  2.64523119e-01 ...  1.85912058e-01
     7.50792101e-02  2.26906434e-01]
   [ 2.45644636e-02  2.49361265e-02  4.09977913e-01 ...  2.56267220e-01
     1.35049358e-01  2.58147597e-01]
   [-1.11659914e-02  5.32234684e-02  4.76467907e-01 ...  3.11896145e-01
     1.58500791e-01  2.89531261e-01]
   ...
   [-2.83652470e-02  6.99636936e-02  5.44180274e-01 ...  3.79703373e-01
     2.03064114e-01  3.48033041e-01]
   [-3.48832272e-02  5.73238991e-02  4.96190369e-01 ...  3.72090876e-01
     1.81635693e-01  3.58659029e-01]
   [-3.89259271e-02  2.73454078e-02  2.71928340e-01 ...  3.39304954e-01
     8.07696208e-02  2.54779845e-01]]

  [[ 5.60188815e-02  3.48543599e-02  2.20965892e-01 ...  1.62926540e-01
     7.29940161e-02  2.05437750e-01]
   [ 2.22457871e-02  4.96294089e-02  3.31111431e-01 ...  2.18296364e-01
     1.22494869e-01  2.22173691e-01]
   [-3.72293964e-03  8.88322666e-02  3.88356268e-01 ...  2.59815037e-01
     1.47945717e-01  2.44451314e-01]
   ...
   [-2.03779750e-02  1.10053554e-01  4.32410121e-01 ...  3.11534166e-01
     1.87783048e-01  3.03887814e-01]
   [-3.06747276e-02  9.53165665e-02  3.81398976e-01 ...  2.97086895e-01
     1.66954249e-01  2.94213474e-01]
   [-3.05685177e-02  3.21246460e-02  2.30680123e-01 ...  2.59103298e-01
     9.64258835e-02  2.30618730e-01]]

  [[ 4.33812030e-02  3.45876142e-02  1.36073053e-01 ...  7.24842548e-02
     5.62026761e-02  1.23662286e-01]
   [ 2.54868492e-02  5.30315749e-02  2.28306949e-01 ...  9.75592732e-02
     9.43436623e-02  1.28626943e-01]
   [ 1.92903653e-02  7.24871531e-02  2.70146757e-01 ...  1.20144218e-01
     1.10490926e-01  1.47676066e-01]
   ...
   [ 7.02063087e-03  8.36579055e-02  3.48761082e-01 ...  1.58494785e-01
     1.30590826e-01  1.79404482e-01]
   [-1.82694057e-05  5.76314554e-02  3.16185445e-01 ...  1.54782787e-01
     1.05770782e-01  1.70392156e-01]
   [-4.06836253e-03  1.12849586e-02  2.12644458e-01 ...  1.02986082e-01
     6.46547675e-02  9.39983502e-02]]]], shape=(1, 40, 40, 512), dtype=float32)
tf.Tensor(
[[[[ 0.12318148  0.03376219  0.01146538 ...  0.09772492  0.24040186
     0.14421844]
   [ 0.1259669   0.05908597  0.09456294 ...  0.2132321   0.3147785
     0.12912317]
   ...
   [ 0.10069057  0.07071531  0.42017645 ...  0.1173479   0.2154476
    -0.0434314 ]]]], shape=(1, 20, 20, 1024), dtype=float32)

From this output we can see that x3 is a 20*20 feature map with 1024 channels. After a few more convolution layers of feature extraction, let's look at how the first detection scale is built.
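The 20*20 spatial size follows directly from the 640*640 input and Darknet's five stride-2 downsamplings (a total stride of 32); a quick sanity check:

```python
# Darknet downsamples the input five times with stride-2 convolutions,
# so the deepest feature map is the input size divided by 2**5 = 32.
input_size = 640
stride = 2 ** 5
print(input_size // stride)  # 20
```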

if isinstance(x3, tuple):
    x, x_skip = x3[0], x3[1]

    # concat with skip connection
    x = darknet._darknet_conv(x, 512, 1)
    x = layers.UpSampling2D(2)(x)
    x = layers.Concatenate()([x, x_skip])
else:
    x = x3
# continue extracting features
x = darknet._darknet_conv(x, 512, 1)
x = darknet._darknet_conv(x, 512 * 2, 3)
x = darknet._darknet_conv(x, 512, 1)
x = darknet._darknet_conv(x, 512 * 2, 3)
x = darknet._darknet_conv(x, 512, 1)
# first branch point, reused when building the next scale
concat_output = x

x = darknet._darknet_conv(x, 512 * 2, 3)
# 9 anchor clusters split across 3 scales
anchor_masks = np.array([[6, 7, 8], [3, 4, 5], [0, 1, 2]])
# 3 anchors per scale
num_anchors = len(anchor_masks[0])
# [batch, h, w, num_anchors * (num_class + 5)]
# no batch norm or activation here; with 91 classes the 1*1 conv
# maps the channels to num_anchors * (91 + 5) = 288
x = darknet._darknet_conv(x, num_anchors * (91 + 5), 1, batch_norm=False)
# [batch, h, w, num_anchors, (num_class + 5)]
# per anchor: coordinate offsets, width/height, objectness, and 91 class scores
x = layers.Lambda(lambda x: tf.reshape(x, (-1, tf.shape(x)[1], tf.shape(x)[2],
                                                    num_anchors, 91 + 5)))(x)
print('x-feature', x)

Output:

x-feature tf.Tensor(
[[[[[ 1.43855521e-02  3.23734321e-02 -3.67170293e-03 ...
     -5.93810296e-03 -2.72588581e-02 -3.25208111e-03]
    [-3.53677664e-03  6.76192809e-03 -4.44413349e-03 ...
      2.82656588e-03 -7.69695640e-03  3.29773836e-02]
    [ 8.74534249e-04  3.14397179e-02  2.98833847e-02 ...
     -6.54635532e-03 -1.05559025e-02  8.90140049e-03]]

   ...

   [[ 5.91525901e-03  2.44261324e-03 -3.34736444e-02 ...
     -2.78865341e-02 -6.24412596e-02 -1.12322122e-01]
    [-8.07825625e-02  3.91773060e-02 -4.91225272e-02 ...
     -3.75587940e-02 -2.48860158e-02  4.39865328e-02]
    [ 7.03450711e-03 -4.43650782e-03  1.44479815e-02 ...
      6.97280839e-02  3.27673666e-02 -1.41870379e-02]]]]], shape=(1, 20, 20, 3, 96), dtype=float32)

From this output we can see that the tensor represents a 20*20 feature map with 3 anchor sizes per cell; each anchor carries 4 coordinate values, one confidence score, and 91 class scores (4 + 1 + 91 = 96).
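The reshape behind this tensor can be verified with plain NumPy (a minimal sketch; the values are zeros, only the shapes matter):

```python
import numpy as np

num_anchors, num_classes = 3, 91
# flat prediction map: 20*20 grid, num_anchors * (num_classes + 5) = 288 channels
flat = np.zeros((1, 20, 20, num_anchors * (num_classes + 5)), dtype=np.float32)
# split the channel axis into separate anchor and prediction-vector axes
per_anchor = flat.reshape(1, 20, 20, num_anchors, num_classes + 5)
print(per_anchor.shape)  # (1, 20, 20, 3, 96)
```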

# anchor-box shapes (width, height), normalized by the input image size
anchors = np.array([[17, 20], [43, 52], [66, 127], [132, 69], [116, 243], [205, 149],
                    [233, 363], [410, 216], [496, 440]], np.float32) / image_shape[0]
first_out_bbox, first_out_objectness, first_out_class_probs, first_out_pred_box = layers.Lambda(
    lambda x: yolo_boxes(x, anchors[anchor_masks[0]], 91),
    name='yolo_boxes_first_out')(x)

For each output tensor we decode, per anchor, the bounding box, objectness score, class probabilities, and raw coordinate offsets. This differs from Faster RCNN's anchor mechanism: Faster RCNN maps each pixel of the feature map back to a region of the original image and places 9 anchors of different shapes at its center to obtain bounding boxes, whereas YOLO divides the feature map into grid cells and predicts whether each cell contains a target. YOLO uses three different grid granularities (one per scale), and each cell serves as the reference point for its own predictions.
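Before walking through yolo_boxes, the per-cell center decoding can be sketched in isolation (the cell indices and raw offsets below are made-up values for illustration):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

grid_size = 20
cx, cy = 7, 4            # grid-cell indices (hypothetical)
tx, ty = 0.3, -1.2       # raw network offsets (hypothetical)
# sigmoid keeps the offset inside the cell; adding the cell index and
# dividing by the grid size gives a box center normalized to 0~1
x = (sigmoid(tx) + cx) / grid_size
y = (sigmoid(ty) + cy) / grid_size
print(round(x, 3), round(y, 3))  # 0.379 0.212
```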

def yolo_boxes(pred, anchors, num_classes):
    """Decode raw predictions into boxes, objectness, and class probabilities."""
    # pred: (batch_size, grid, grid, anchors, (x, y, w, h, obj, ...classes))
    # grid size of the feature map; larger grids detect smaller objects,
    # smaller grids detect larger objects
    grid_size = tf.shape(pred)[1:3]
    print('grid_size', grid_size)
    # split the last axis 2:2:1:91 -- the first two values are the center
    # offsets, the next two the width/height, then objectness, then class scores
    box_xy, box_wh, objectness, class_probs = tf.split(pred, (2, 2, 1, num_classes), axis=-1)
    # squash the center offsets into 0~1 so the center stays inside its cell
    box_xy = tf.sigmoid(box_xy)
    # squash objectness into 0~1
    objectness = tf.sigmoid(objectness)
    # normalize the 91 class scores into probabilities
    class_probs = layers.Softmax()(class_probs)
    # raw xywh, kept for the loss computation
    pred_box = tf.concat((box_xy, box_wh), axis=-1)  # original xywh for loss

    # build a grid of cell indices; each entry is the top-left corner of a cell
    grid = tf.meshgrid(tf.range(grid_size[1]), tf.range(grid_size[0]))
    # stack into (x, y) pairs and add an anchor axis
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)  # [gy, gx, 1, 2]
    # absolute center = cell corner + offset, normalized to 0~1
    box_xy = (box_xy + tf.cast(grid, tf.float32)) / tf.cast(grid_size, tf.float32)
    # recover width/height from the anchor priors
    box_wh = tf.exp(box_wh) * anchors
    print('box_wh', box_wh)
    # convert center/size into top-left and bottom-right corners
    box_x1y1 = box_xy - box_wh / 2
    box_x2y2 = box_xy + box_wh / 2
    x1, y1 = tf.split(box_x1y1, (1, 1), axis=-1)
    x2, y2 = tf.split(box_x2y2, (1, 1), axis=-1)
    # coordinates are normalized, so clip them to the 0~1 range
    x1 = tf.minimum(tf.maximum(x1, 0.), 1.)
    y1 = tf.minimum(tf.maximum(y1, 0.), 1.)
    x2 = tf.minimum(tf.maximum(x2, 0.), 1.)
    y2 = tf.minimum(tf.maximum(y2, 0.), 1.)
    # concatenate into a single box tensor for the IoU computation
    bbox = tf.concat([x1, y1, x2, y2], axis=-1)

    return bbox, objectness, class_probs, pred_box

Output:

grid_size tf.Tensor([20 20], shape=(2,), dtype=int32)
grid_size tf.Tensor([40 40], shape=(2,), dtype=int32)
grid_size tf.Tensor([80 80], shape=(2,), dtype=int32)
box_wh tf.Tensor(
[[[[[6.39551878e-01 3.97852451e-01]
    [7.99882174e-01 3.92906696e-01]
    [6.22583508e-01 1.02322054e+00]]]]]
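The width/height decoding box_wh = exp(tw) * anchors means a raw prediction of zero reproduces the anchor prior exactly; a small check with the first-scale priors (values taken from the anchors array above):

```python
import numpy as np

image_size = 640.0
# first-scale anchor priors, normalized by the input image size
anchors = np.array([[233, 363], [410, 216], [496, 440]], np.float32) / image_size
tw_th = np.zeros_like(anchors)       # raw width/height predictions of zero
box_wh = np.exp(tw_th) * anchors     # exp(0) = 1, so each box equals its prior
print(np.allclose(box_wh, anchors))  # True
```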

Next we build the second scale:

feature_maps = (concat_output, x2)
if isinstance(feature_maps, tuple):
    x, x_skip = feature_maps[0], feature_maps[1]

    # concat with skip connection
    x = darknet._darknet_conv(x, 256, 1)
    # upsample the first-scale features so their size matches x2
    x = layers.UpSampling2D(2)(x)
    # concatenate x2 with the upsampled features
    x = layers.Concatenate()([x, x_skip])
else:
    x = feature_maps
# continue extracting features
x = darknet._darknet_conv(x, 256, 1)
x = darknet._darknet_conv(x, 256 * 2, 3)
x = darknet._darknet_conv(x, 256, 1)
x = darknet._darknet_conv(x, 256 * 2, 3)
x = darknet._darknet_conv(x, 256, 1)
# second branch point, reused when building the next scale
concat_output = x

x = darknet._darknet_conv(x, 256 * 2, 3)
num_anchors = len(anchor_masks[1])
# [batch, h, w, num_anchors * (num_class + 5)]
# no batch norm or activation here; with 91 classes the 1*1 conv
# maps the channels to num_anchors * (91 + 5) = 288
x = darknet._darknet_conv(x, num_anchors * (91 + 5), 1, batch_norm=False)
# [batch, h, w, num_anchors, (num_class + 5)]
# per anchor: coordinate offsets, width/height, objectness, and 91 class scores
x = layers.Lambda(lambda x: tf.reshape(x, (-1, tf.shape(x)[1], tf.shape(x)[2],
                                                    num_anchors, 91 + 5)))(x)
# decode each anchor's bounding box, objectness, class probabilities, and raw offsets
second_out_bbox, second_out_objectness, second_out_class_probs, second_out_pred_box = layers.Lambda(
    lambda x: yolo_boxes(x, anchors[anchor_masks[1]], 91),
    name='yolo_boxes_second_out')(x)

Finally, the third scale:

feature_maps = (concat_output, x1)
if isinstance(feature_maps, tuple):
    x, x_skip = feature_maps[0], feature_maps[1]

    # concat with skip connection
    x = darknet._darknet_conv(x, 128, 1)
    # upsample the second-scale features so their size matches x1
    x = layers.UpSampling2D(2)(x)
    # concatenate x1 with the upsampled features
    x = layers.Concatenate()([x, x_skip])
else:
    x = feature_maps
# continue extracting features
x = darknet._darknet_conv(x, 128, 1)
x = darknet._darknet_conv(x, 128 * 2, 3)
x = darknet._darknet_conv(x, 128, 1)
x = darknet._darknet_conv(x, 128 * 2, 3)
x = darknet._darknet_conv(x, 128, 1)
# third branch point
concat_output = x

x = darknet._darknet_conv(x, 128 * 2, 3)
num_anchors = len(anchor_masks[2])
# [batch, h, w, num_anchors * (num_class + 5)]
# no batch norm or activation here; with 91 classes the 1*1 conv
# maps the channels to num_anchors * (91 + 5) = 288
x = darknet._darknet_conv(x, num_anchors * (91 + 5), 1, batch_norm=False)
# [batch, h, w, num_anchors, (num_class + 5)]
# per anchor: coordinate offsets, width/height, objectness, and 91 class scores
x = layers.Lambda(lambda x: tf.reshape(x, (-1, tf.shape(x)[1], tf.shape(x)[2],
                                           num_anchors, 91 + 5)))(x)
# decode each anchor's bounding box, objectness, class probabilities, and raw offsets
third_out_bbox, third_out_objectness, third_out_class_probs, third_out_pred_box = layers.Lambda(
    lambda x: yolo_boxes(x, anchors[anchor_masks[2]], 91),
    name='yolo_boxes_third_out')(x)

is_training = True
if is_training:
    model = models.Model(inputs=inputs, outputs=[
        [first_out_bbox, first_out_objectness, first_out_class_probs, first_out_pred_box],
        [second_out_bbox, second_out_objectness, second_out_class_probs, second_out_pred_box],
        [third_out_bbox, third_out_objectness, third_out_class_probs, third_out_pred_box]
    ])
    print(model.summary())

To be able to print the network structure here, we change the earlier call to:

x1, x2, x3 = darknet.build_darknet(inputs, "darknet")

Output:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_images (InputLayer)       [(None, 640, 640, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 640, 640, 32) 864         input_images[0][0]               
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 640, 640, 32) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 640, 640, 32) 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 641, 641, 32) 0           leaky_re_lu[0][0]                
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 320, 320, 64) 18432       zero_padding2d[0][0]             
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 320, 320, 64) 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 320, 320, 64) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 320, 320, 32) 2048        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 320, 320, 32) 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 320, 320, 32) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 320, 320, 64) 18432       leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 320, 320, 64) 256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 320, 320, 64) 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
add (Add)                       (None, 320, 320, 64) 0           leaky_re_lu_1[0][0]              
                                                                 leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 321, 321, 64) 0           add[0][0]                        
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 160, 160, 128 73728       zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 160, 160, 128 512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 160, 160, 128 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 160, 160, 64) 8192        leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 160, 160, 64) 256         conv2d_5[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 160, 160, 64) 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 160, 160, 128 73728       leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 160, 160, 128 512         conv2d_6[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 160, 160, 128 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
add_1 (Add)                     (None, 160, 160, 128 0           leaky_re_lu_4[0][0]              
                                                                 leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 160, 160, 64) 8192        add_1[0][0]                      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 160, 160, 64) 256         conv2d_7[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 160, 160, 64) 0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 160, 160, 128 73728       leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 160, 160, 128 512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 160, 160, 128 0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 160, 160, 128 0           add_1[0][0]                      
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 161, 161, 128 0           add_2[0][0]                      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 80, 80, 256)  294912      zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 80, 80, 256)  1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 80, 80, 256)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 80, 80, 128)  32768       leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 80, 80, 128)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 80, 80, 256)  1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
add_3 (Add)                     (None, 80, 80, 256)  0           leaky_re_lu_9[0][0]              
                                                                 leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 80, 80, 128)  32768       add_3[0][0]                      
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 80, 80, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_12[0][0]             
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 80, 80, 256)  1024        conv2d_13[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
add_4 (Add)                     (None, 80, 80, 256)  0           add_3[0][0]                      
                                                                 leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 80, 80, 128)  32768       add_4[0][0]                      
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 80, 80, 128)  512         conv2d_14[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 80, 80, 256)  1024        conv2d_15[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
add_5 (Add)                     (None, 80, 80, 256)  0           add_4[0][0]                      
                                                                 leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 80, 80, 128)  32768       add_5[0][0]                      
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 80, 80, 128)  512         conv2d_16[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_16[0][0]             
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 80, 80, 256)  1024        conv2d_17[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
add_6 (Add)                     (None, 80, 80, 256)  0           add_5[0][0]                      
                                                                 leaky_re_lu_17[0][0]             
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 80, 80, 128)  32768       add_6[0][0]                      
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 80, 80, 128)  512         conv2d_18[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_18 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 80, 80, 256)  1024        conv2d_19[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_19 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
add_7 (Add)                     (None, 80, 80, 256)  0           add_6[0][0]                      
                                                                 leaky_re_lu_19[0][0]             
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 80, 80, 128)  32768       add_7[0][0]                      
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 80, 80, 128)  512         conv2d_20[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_20 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_20[0][0]     
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_20[0][0]             
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 80, 80, 256)  1024        conv2d_21[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_21 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_21[0][0]     
__________________________________________________________________________________________________
add_8 (Add)                     (None, 80, 80, 256)  0           add_7[0][0]                      
                                                                 leaky_re_lu_21[0][0]             
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 80, 80, 128)  32768       add_8[0][0]                      
__________________________________________________________________________________________________
batch_normalization_22 (BatchNo (None, 80, 80, 128)  512         conv2d_22[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_22[0][0]     
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
batch_normalization_23 (BatchNo (None, 80, 80, 256)  1024        conv2d_23[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_23[0][0]     
__________________________________________________________________________________________________
add_9 (Add)                     (None, 80, 80, 256)  0           add_8[0][0]                      
                                                                 leaky_re_lu_23[0][0]             
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 80, 80, 128)  32768       add_9[0][0]                      
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, 80, 80, 128)  512         conv2d_24[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_24 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_24[0][0]     
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_24[0][0]             
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 80, 80, 256)  1024        conv2d_25[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_25 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
add_10 (Add)                    (None, 80, 80, 256)  0           add_9[0][0]                      
                                                                 leaky_re_lu_25[0][0]             
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 81, 81, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 40, 40, 512)  1179648     zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 40, 40, 512)  2048        conv2d_26[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_26 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_26[0][0]     
__________________________________________________________________________________________________
conv2d_27 (Conv2D)              (None, 40, 40, 256)  131072      leaky_re_lu_26[0][0]             
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 40, 40, 256)  1024        conv2d_27[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_27 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_27[0][0]     
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_27[0][0]             
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 40, 40, 512)  2048        conv2d_28[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_28 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
add_11 (Add)                    (None, 40, 40, 512)  0           leaky_re_lu_26[0][0]             
                                                                 leaky_re_lu_28[0][0]             
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 40, 40, 256)  131072      add_11[0][0]                     
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 40, 40, 256)  1024        conv2d_29[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_29 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_29[0][0]     
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_29[0][0]             
__________________________________________________________________________________________________
batch_normalization_30 (BatchNo (None, 40, 40, 512)  2048        conv2d_30[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_30 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_30[0][0]     
__________________________________________________________________________________________________
add_12 (Add)                    (None, 40, 40, 512)  0           add_11[0][0]                     
                                                                 leaky_re_lu_30[0][0]             
__________________________________________________________________________________________________
conv2d_31 (Conv2D)              (None, 40, 40, 256)  131072      add_12[0][0]                     
__________________________________________________________________________________________________
batch_normalization_31 (BatchNo (None, 40, 40, 256)  1024        conv2d_31[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_31 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_31[0][0]     
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_31[0][0]             
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 40, 40, 512)  2048        conv2d_32[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_32[0][0]     
__________________________________________________________________________________________________
add_13 (Add)                    (None, 40, 40, 512)  0           add_12[0][0]                     
                                                                 leaky_re_lu_32[0][0]             
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 40, 40, 256)  131072      add_13[0][0]                     
__________________________________________________________________________________________________
batch_normalization_33 (BatchNo (None, 40, 40, 256)  1024        conv2d_33[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_33 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_33[0][0]     
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_33[0][0]             
__________________________________________________________________________________________________
batch_normalization_34 (BatchNo (None, 40, 40, 512)  2048        conv2d_34[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_34 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_34[0][0]     
__________________________________________________________________________________________________
add_14 (Add)                    (None, 40, 40, 512)  0           add_13[0][0]                     
                                                                 leaky_re_lu_34[0][0]             
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 40, 40, 256)  131072      add_14[0][0]                     
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 40, 40, 256)  1024        conv2d_35[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_35 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_35[0][0]     
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_35[0][0]             
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 40, 40, 512)  2048        conv2d_36[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_36 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_36[0][0]     
__________________________________________________________________________________________________
add_15 (Add)                    (None, 40, 40, 512)  0           add_14[0][0]                     
                                                                 leaky_re_lu_36[0][0]             
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 40, 40, 256)  131072      add_15[0][0]                     
__________________________________________________________________________________________________
batch_normalization_37 (BatchNo (None, 40, 40, 256)  1024        conv2d_37[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_37 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_37[0][0]     
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_37[0][0]             
__________________________________________________________________________________________________
batch_normalization_38 (BatchNo (None, 40, 40, 512)  2048        conv2d_38[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_38 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_38[0][0]     
__________________________________________________________________________________________________
add_16 (Add)                    (None, 40, 40, 512)  0           add_15[0][0]                     
                                                                 leaky_re_lu_38[0][0]             
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 40, 40, 256)  131072      add_16[0][0]                     
__________________________________________________________________________________________________
batch_normalization_39 (BatchNo (None, 40, 40, 256)  1024        conv2d_39[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_39 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_39[0][0]     
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_39[0][0]             
__________________________________________________________________________________________________
batch_normalization_40 (BatchNo (None, 40, 40, 512)  2048        conv2d_40[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_40 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_40[0][0]     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 40, 40, 512)  0           add_16[0][0]                     
                                                                 leaky_re_lu_40[0][0]             
__________________________________________________________________________________________________
conv2d_41 (Conv2D)              (None, 40, 40, 256)  131072      add_17[0][0]                     
__________________________________________________________________________________________________
batch_normalization_41 (BatchNo (None, 40, 40, 256)  1024        conv2d_41[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_41 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_41[0][0]     
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_41[0][0]             
__________________________________________________________________________________________________
batch_normalization_42 (BatchNo (None, 40, 40, 512)  2048        conv2d_42[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_42 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_42[0][0]     
__________________________________________________________________________________________________
add_18 (Add)                    (None, 40, 40, 512)  0           add_17[0][0]                     
                                                                 leaky_re_lu_42[0][0]             
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 41, 41, 512)  0           add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_43 (Conv2D)              (None, 20, 20, 1024) 4718592     zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
batch_normalization_43 (BatchNo (None, 20, 20, 1024) 4096        conv2d_43[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_43 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_43[0][0]     
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 20, 20, 512)  524288      leaky_re_lu_43[0][0]             
__________________________________________________________________________________________________
batch_normalization_44 (BatchNo (None, 20, 20, 512)  2048        conv2d_44[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_44 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_44[0][0]     
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_44[0][0]             
__________________________________________________________________________________________________
batch_normalization_45 (BatchNo (None, 20, 20, 1024) 4096        conv2d_45[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_45 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_45[0][0]     
__________________________________________________________________________________________________
add_19 (Add)                    (None, 20, 20, 1024) 0           leaky_re_lu_43[0][0]             
                                                                 leaky_re_lu_45[0][0]             
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 20, 20, 512)  524288      add_19[0][0]                     
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 20, 20, 512)  2048        conv2d_46[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_46[0][0]     
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_46[0][0]             
__________________________________________________________________________________________________
batch_normalization_47 (BatchNo (None, 20, 20, 1024) 4096        conv2d_47[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_47 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_47[0][0]     
__________________________________________________________________________________________________
add_20 (Add)                    (None, 20, 20, 1024) 0           add_19[0][0]                     
                                                                 leaky_re_lu_47[0][0]             
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 20, 20, 512)  524288      add_20[0][0]                     
__________________________________________________________________________________________________
batch_normalization_48 (BatchNo (None, 20, 20, 512)  2048        conv2d_48[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_48 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_48[0][0]     
__________________________________________________________________________________________________
conv2d_49 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_48[0][0]             
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 20, 20, 1024) 4096        conv2d_49[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_49 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_49[0][0]     
__________________________________________________________________________________________________
add_21 (Add)                    (None, 20, 20, 1024) 0           add_20[0][0]                     
                                                                 leaky_re_lu_49[0][0]             
__________________________________________________________________________________________________
conv2d_50 (Conv2D)              (None, 20, 20, 512)  524288      add_21[0][0]                     
__________________________________________________________________________________________________
batch_normalization_50 (BatchNo (None, 20, 20, 512)  2048        conv2d_50[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_50 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_50[0][0]     
__________________________________________________________________________________________________
conv2d_51 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_50[0][0]             
__________________________________________________________________________________________________
batch_normalization_51 (BatchNo (None, 20, 20, 1024) 4096        conv2d_51[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_51 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_51[0][0]     
__________________________________________________________________________________________________
add_22 (Add)                    (None, 20, 20, 1024) 0           add_21[0][0]                     
                                                                 leaky_re_lu_51[0][0]             
__________________________________________________________________________________________________
conv2d_52 (Conv2D)              (None, 20, 20, 512)  524288      add_22[0][0]                     
__________________________________________________________________________________________________
batch_normalization_52 (BatchNo (None, 20, 20, 512)  2048        conv2d_52[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_52 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_52[0][0]     
__________________________________________________________________________________________________
conv2d_53 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_52[0][0]             
__________________________________________________________________________________________________
batch_normalization_53 (BatchNo (None, 20, 20, 1024) 4096        conv2d_53[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_53 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_53[0][0]     
__________________________________________________________________________________________________
conv2d_54 (Conv2D)              (None, 20, 20, 512)  524288      leaky_re_lu_53[0][0]             
__________________________________________________________________________________________________
batch_normalization_54 (BatchNo (None, 20, 20, 512)  2048        conv2d_54[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_54 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_54[0][0]     
__________________________________________________________________________________________________
conv2d_55 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_54[0][0]             
__________________________________________________________________________________________________
batch_normalization_55 (BatchNo (None, 20, 20, 1024) 4096        conv2d_55[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_55 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_55[0][0]     
__________________________________________________________________________________________________
conv2d_56 (Conv2D)              (None, 20, 20, 512)  524288      leaky_re_lu_55[0][0]             
__________________________________________________________________________________________________
batch_normalization_56 (BatchNo (None, 20, 20, 512)  2048        conv2d_56[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_56 (LeakyReLU)      (None, 20, 20, 512)  0           batch_normalization_56[0][0]     
__________________________________________________________________________________________________
conv2d_59 (Conv2D)              (None, 20, 20, 256)  131072      leaky_re_lu_56[0][0]             
__________________________________________________________________________________________________
batch_normalization_58 (BatchNo (None, 20, 20, 256)  1024        conv2d_59[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_58 (LeakyReLU)      (None, 20, 20, 256)  0           batch_normalization_58[0][0]     
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 40, 40, 256)  0           leaky_re_lu_58[0][0]             
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 40, 40, 768)  0           up_sampling2d[0][0]              
                                                                 add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_60 (Conv2D)              (None, 40, 40, 256)  196608      concatenate[0][0]                
__________________________________________________________________________________________________
batch_normalization_59 (BatchNo (None, 40, 40, 256)  1024        conv2d_60[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_59 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_59[0][0]     
__________________________________________________________________________________________________
conv2d_61 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_59[0][0]             
__________________________________________________________________________________________________
batch_normalization_60 (BatchNo (None, 40, 40, 512)  2048        conv2d_61[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_60 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_60[0][0]     
__________________________________________________________________________________________________
conv2d_62 (Conv2D)              (None, 40, 40, 256)  131072      leaky_re_lu_60[0][0]             
__________________________________________________________________________________________________
batch_normalization_61 (BatchNo (None, 40, 40, 256)  1024        conv2d_62[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_61 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_61[0][0]     
__________________________________________________________________________________________________
conv2d_63 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_61[0][0]             
__________________________________________________________________________________________________
batch_normalization_62 (BatchNo (None, 40, 40, 512)  2048        conv2d_63[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_62 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_62[0][0]     
__________________________________________________________________________________________________
conv2d_64 (Conv2D)              (None, 40, 40, 256)  131072      leaky_re_lu_62[0][0]             
__________________________________________________________________________________________________
batch_normalization_63 (BatchNo (None, 40, 40, 256)  1024        conv2d_64[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_63 (LeakyReLU)      (None, 40, 40, 256)  0           batch_normalization_63[0][0]     
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, 40, 40, 128)  32768       leaky_re_lu_63[0][0]             
__________________________________________________________________________________________________
batch_normalization_65 (BatchNo (None, 40, 40, 128)  512         conv2d_67[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_65 (LeakyReLU)      (None, 40, 40, 128)  0           batch_normalization_65[0][0]     
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 80, 80, 128)  0           leaky_re_lu_65[0][0]             
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 80, 80, 384)  0           up_sampling2d_1[0][0]            
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, 80, 80, 128)  49152       concatenate_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_66 (BatchNo (None, 80, 80, 128)  512         conv2d_68[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_66 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_66[0][0]     
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_66[0][0]             
__________________________________________________________________________________________________
batch_normalization_67 (BatchNo (None, 80, 80, 256)  1024        conv2d_69[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_67 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_67[0][0]     
__________________________________________________________________________________________________
conv2d_70 (Conv2D)              (None, 80, 80, 128)  32768       leaky_re_lu_67[0][0]             
__________________________________________________________________________________________________
batch_normalization_68 (BatchNo (None, 80, 80, 128)  512         conv2d_70[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_68 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_68[0][0]     
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_68[0][0]             
__________________________________________________________________________________________________
batch_normalization_69 (BatchNo (None, 80, 80, 256)  1024        conv2d_71[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_69 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_69[0][0]     
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 80, 80, 128)  32768       leaky_re_lu_69[0][0]             
__________________________________________________________________________________________________
batch_normalization_70 (BatchNo (None, 80, 80, 128)  512         conv2d_72[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_70 (LeakyReLU)      (None, 80, 80, 128)  0           batch_normalization_70[0][0]     
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 20, 20, 1024) 4718592     leaky_re_lu_56[0][0]             
__________________________________________________________________________________________________
conv2d_65 (Conv2D)              (None, 40, 40, 512)  1179648     leaky_re_lu_63[0][0]             
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 80, 80, 256)  294912      leaky_re_lu_70[0][0]             
__________________________________________________________________________________________________
batch_normalization_57 (BatchNo (None, 20, 20, 1024) 4096        conv2d_57[0][0]                  
__________________________________________________________________________________________________
batch_normalization_64 (BatchNo (None, 40, 40, 512)  2048        conv2d_65[0][0]                  
__________________________________________________________________________________________________
batch_normalization_71 (BatchNo (None, 80, 80, 256)  1024        conv2d_73[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_57 (LeakyReLU)      (None, 20, 20, 1024) 0           batch_normalization_57[0][0]     
__________________________________________________________________________________________________
leaky_re_lu_64 (LeakyReLU)      (None, 40, 40, 512)  0           batch_normalization_64[0][0]     
__________________________________________________________________________________________________
leaky_re_lu_71 (LeakyReLU)      (None, 80, 80, 256)  0           batch_normalization_71[0][0]     
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 20, 20, 288)  295200      leaky_re_lu_57[0][0]             
__________________________________________________________________________________________________
conv2d_66 (Conv2D)              (None, 40, 40, 288)  147744      leaky_re_lu_64[0][0]             
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 80, 80, 288)  74016       leaky_re_lu_71[0][0]             
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, None, 3 0           conv2d_58[0][0]                  
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None, None, 3 0           conv2d_66[0][0]                  
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, None, None, 3 0           conv2d_74[0][0]                  
__________________________________________________________________________________________________
yolo_boxes_first_out (Lambda)   ((None, None, None,  0           lambda[0][0]                     
__________________________________________________________________________________________________
yolo_boxes_second_out (Lambda)  ((None, None, None,  0           lambda_1[0][0]                   
__________________________________________________________________________________________________
yolo_boxes_third_out (Lambda)   ((None, None, None,  0           lambda_2[0][0]                   
==================================================================================================
Total params: 62,060,992
Trainable params: 62,008,384
Non-trainable params: 52,608
__________________________________________________________________________________________________
None

Now we apply non-maximum suppression (NMS) to the outputs of the three scales. First, we change the input used to build the network to the image itself

x1, x2, x3 = darknet.build_darknet(img, "darknet")
else:
    outputs = layers.Lambda(lambda x: yolo_nms(x, 91),
                            name='yolo_nms')([
        # boxes_0[:3], boxes_1[:3], boxes_2[:3]
        [first_out_bbox, first_out_objectness, first_out_class_probs],
        [second_out_bbox, second_out_objectness, second_out_class_probs],
        [third_out_bbox, third_out_objectness, third_out_class_probs]
    ])
    model = models.Model(inputs=inputs, outputs=outputs)
    print(model.summary())

The yolo_nms code is as follows

def yolo_nms(yolo_pred, num_class):
    """ Non-maximum suppression over the predicted boxes
    """
    boxes, objectness, class_probs = [], [], []

    # pred: [bbox, objectness, class_probs]
    # Stack the boxes from all three feature-map scales
    for pred in yolo_pred:
        # boxes: [batch, -1, 4]
        boxes.append(tf.reshape(pred[0], (tf.shape(pred[0])[0], -1, tf.shape(pred[0])[-1])))
        # objectness: [batch, -1, 1]
        objectness.append(tf.reshape(pred[1], (tf.shape(pred[1])[0], -1, tf.shape(pred[1])[-1])))
        # class_probs: [batch, -1, num_classes]
        class_probs.append(tf.reshape(pred[2], (tf.shape(pred[2])[0], -1, tf.shape(pred[2])[-1])))
    # Concatenate along axis=1
    bbox = tf.concat(boxes, axis=1)
    objectness = tf.concat(objectness, axis=1)
    class_probs = tf.concat(class_probs, axis=1)

    final_batch_nms_bboxes = []
    final_batch_nms_scores = []
    final_batch_nms_classes = []
    valid_detection_nums = []
    batch_size = 1
    for b in range(batch_size):
        # Objectness * class probability is the score NMS sorts by
        cur_scores = objectness[b] * class_probs[b]

        # In test mode the batch dimension is 1; the original repo squeezes it directly
        # dscores = tf.squeeze(scores, axis=0)
        cur_dscores = tf.reshape(cur_scores, (-1, num_class))
        cur_bbox = tf.reshape(bbox[b], (-1, 4))

        # for i in range(num_class):
        #     cur_dscores_cls = cur_dscores[:, i]

        # Take the highest class probability in each row...
        cur_scores = tf.reduce_max(cur_dscores, [1])
        print('cur_scores', cur_scores)
        # ...and the index of that maximum
        cur_classes = tf.argmax(cur_dscores, 1)
        print('cur_classes', cur_classes)
        # Non-maximum suppression: outputs up to 100 indices and their scores
        selected_indices, selected_scores = tf.image.non_max_suppression_with_scores(
            boxes=cur_bbox,
            scores=cur_scores,
            max_output_size=100,
            iou_threshold=0.5,
            score_threshold=0.007,
            soft_nms_sigma=0.5
        )
        print('selected_indices', selected_indices)
        print('selected_scores', selected_scores)
        # num_valid_nms_boxes = tf.shape(selected_indices)[0]
        # pad_num = self.yolo_max_boxes - num_valid_nms_boxes
        # Pad when there are fewer boxes than the maximum
        # selected_indices = tf.concat([selected_indices, tf.zeros(self.yolo_max_boxes - num_valid_nms_boxes, tf.int32)],
        #                              0)
        # selected_scores = tf.concat([selected_scores, tf.zeros(self.yolo_max_boxes - num_valid_nms_boxes, tf.float32)],
        #                             -1)
        # Number of valid boxes left after NMS
        valid_num = tf.shape(selected_indices)[0]
        valid_detection_nums.append(valid_num)
        pad_num = 100 - valid_num

        # [N, (x1, y1, x2, y2)]
        # Gather the boxes kept by NMS
        cur_bbox = tf.gather(cur_bbox, selected_indices)
        # Pad rows of zeros below the kept boxes
        cur_bbox = tf.pad(cur_bbox, [[0, pad_num], [0, 0]])
        cur_bbox = tf.expand_dims(cur_bbox, axis=0)
        print('cur_bbox', cur_bbox)
        final_batch_nms_bboxes.append(cur_bbox)

        # [1, N]
        # The kept boxes' best class scores
        cur_scores = selected_scores
        cur_scores = tf.pad(cur_scores, [[0, pad_num]])
        cur_scores = tf.expand_dims(cur_scores, axis=0)
        print('cur_scores', cur_scores)
        final_batch_nms_scores.append(cur_scores)

        # [1, N]
        # The kept boxes' class indices
        cur_classes = tf.gather(cur_classes, selected_indices)
        cur_classes = tf.pad(cur_classes, [[0, pad_num]])
        cur_classes = tf.expand_dims(cur_classes, axis=0)
        print('cur_classes', cur_classes)
        final_batch_nms_classes.append(cur_classes)

    final_batch_nms_bboxes = tf.concat(final_batch_nms_bboxes, axis=0)
    final_batch_nms_scores = tf.concat(final_batch_nms_scores, axis=0)
    final_batch_nms_classes = tf.concat(final_batch_nms_classes, axis=0)
    # Return boxes, scores, classes and the number of valid detections per image
    return final_batch_nms_bboxes, final_batch_nms_scores, final_batch_nms_classes, valid_detection_nums

Output

cur_scores tf.Tensor([0.00585971 0.00599038 0.00595966 ... 0.00554877 0.00553788 0.00553541], shape=(25200,), dtype=float32)
cur_classes tf.Tensor([45 47 24 ... 34  9 80], shape=(25200,), dtype=int64)
selected_indices tf.Tensor(
[4866  818 4764 4500 4662 4398  476 4728 2961 4176 4242 5097 3102 5118
 2823 4026 3444  321 4344 5028 3702 2484 2922 4686 2625 2781 3921 3306
 2445 4308 4833 2346 3768  723 4632 2304 4206], shape=(37,), dtype=int32)
selected_scores tf.Tensor(
[0.00931381 0.00896284 0.00856594 0.00833606 0.00819792 0.00801256
 0.00801134 0.00800271 0.00784801 0.00777876 0.00775182 0.00774399
 0.00770376 0.00768892 0.00760151 0.00760061 0.00757784 0.00753411
 0.00752614 0.00748221 0.00747792 0.00742623 0.00737575 0.00737422
 0.00735858 0.00733442 0.00733046 0.00730357 0.00720255 0.00719151
 0.00717866 0.00716299 0.0071443  0.00714163 0.00708044 0.0070156
 0.00700225], shape=(37,), dtype=float32)
cur_bbox tf.Tensor(
[[[0.47525817 0.7058009  0.65038186 0.8253993 ]
  [0.2290636  0.43899518 1.0120833  0.90559846]
  [0.6245662  0.67990446 0.8013499  0.8008765 ]
  [0.425923   0.6314819  0.5999857  0.74891996]
  [0.7734392  0.655851   0.9523676  0.774408  ]
  [0.57445204 0.6052354  0.7513912  0.72489446]
  [0.49578106 0.13718146 1.3462695  0.61382943]
  [0.32203868 0.68235976 0.5037507  0.79731053]
  [0.59745276 0.3060625  0.77827334 0.423403  ]
  [0.7227372  0.5557441  0.9029784  0.67373496]
  [0.27246937 0.58118933 0.4531304  0.6981532 ]
  [0.38823918 0.7577431  0.5873878  0.8722745 ]
  [0.77167207 0.33106077 0.9540772  0.44833946]
  [0.56604093 0.7558559  0.75961024 0.8745059 ]
  [0.44609636 0.28142396 0.6295416  0.39770314]
  [0.47225362 0.5309098  0.65343183 0.6481316 ]
  [0.620741   0.4065663  0.8049351  0.52262986]
  [0.22502528 0.         0.523556   0.5531311 ]
  [0.11964019 0.6067808  0.3058835  0.72229236]
  [0.81782186 0.729955   1.007774   0.84961873]
  [0.7708911  0.45628837 0.9547696  0.5728136 ]
  [0.6212652  0.20621267 0.8043301  0.32260957]
  [0.2703642  0.30675125 0.45527455 0.42211968]
  [0.         0.68376225 0.1580659  0.7943612 ]
  [0.79585207 0.23102117 0.9797478  0.3477431 ]
  [0.09455924 0.2818944  0.28103435 0.396766  ]
  [0.59682685 0.50606334 0.7788772  0.62299573]
  [0.4687725  0.38219234 0.65683836 0.49669072]
  [0.2954356  0.2064167  0.4800799  0.3222209 ]
  [0.82373786 0.5792823  1.0020474  0.7000315 ]
  [0.18977335 0.7078563  0.3858095  0.82072246]
  [0.46975034 0.1813986  0.6558301  0.29699862]
  [0.31987333 0.48147443 0.50574625 0.59682184]
  [0.         0.34287256 0.22398928 0.9016189 ]
  [0.52763104 0.6548126  0.69821584 0.7761345 ]
  [0.11914106 0.1818139  0.3063051  0.29622966]
  [0.         0.5828472  0.15521356 0.6949684 ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]]], shape=(1, 100, 4), dtype=float32)
cur_scores tf.Tensor(
[[0.00931381 0.00896284 0.00856594 0.00833606 0.00819792 0.00801256
  0.00801134 0.00800271 0.00784801 0.00777876 0.00775182 0.00774399
  0.00770376 0.00768892 0.00760151 0.00760061 0.00757784 0.00753411
  0.00752614 0.00748221 0.00747792 0.00742623 0.00737575 0.00737422
  0.00735858 0.00733442 0.00733046 0.00730357 0.00720255 0.00719151
  0.00717866 0.00716299 0.0071443  0.00714163 0.00708044 0.0070156
  0.00700225 0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]], shape=(1, 100), dtype=float32)
cur_classes tf.Tensor(
[[29 28 29 29 29 29 39 29 29 29 29 29 29 29 29 29 29 27 29 29 29 29 29 29
  29 29 29 29 29 29 29 29 29 20 29 29 29  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]], shape=(1, 100), dtype=int64)
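The tf.image.non_max_suppression_with_scores call above actually performs Soft-NMS (soft_nms_sigma=0.5 decays the scores of overlapping boxes rather than dropping them outright). The classic greedy hard-NMS that it generalizes can be sketched in plain NumPy; this is a simplified illustration, not the TensorFlow implementation:

```python
import numpy as np

def iou(box, boxes):
    # IoU between one box and an array of boxes, all in [x1, y1, x2, y2] format
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1)
    area_a = (box[2] - box[0]) * (box[3] - box[1])
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area_a + area_b - inter)

def hard_nms(boxes, scores, iou_threshold=0.5, max_output=100):
    # Greedy NMS: repeatedly keep the highest-scoring box and drop every
    # remaining box whose overlap with it exceeds iou_threshold
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0 and len(keep) < max_output:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        order = rest[iou(boxes[i], boxes[rest]) <= iou_threshold]
    return np.array(keep, dtype=np.int32)

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(hard_nms(boxes, scores))  # [0 2]: the second box overlaps the first too much
```

Soft-NMS merely replaces the hard `<= iou_threshold` cut with a Gaussian score decay, which is why the TF op also returns the (possibly lowered) selected_scores.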

Training on the COCO dataset

Here we change the value passed in to a TensorFlow layers.Input tensor

x1, x2, x3 = darknet.build_darknet(inputs, "darknet")

To use the COCO dataset, we need to install the COCO toolkit

pip install pycocotools==2.0.0

Then we set up the optimizer and the data generator:

optimizer = optimizers.Adam(learning_rate=0.001)
train_data = CoCoDataGenrator(
    coco_annotation_file="./data/instances_val2017.json",
    img_shape=[640, 640, 3],
    batch_size=5,
    max_instances=100
)

Let's first look at the structure of instances_val2017.json. The keys info and licenses are unimportant and may be empty. images and annotations hold the images and their labels. images is a list whose items are dicts; one image entry looks like this (coco_url is the image's URL, height and width are the image's height and width):

"images": [{
    "license": 4,
    "file_name": "000000397133.jpg",
    "coco_url": "http://images.cocodataset.org/val2017/000000397133.jpg",
    "height": 427,
    "width": 640,
    "date_captured": "2013-11-14 17:02:52",
    "flickr_url": "http://farm7.staticflickr.com/6116/6255196340_da26cf2c9e_z.jpg",
    "id": 397133
}]

annotations is also a list of dicts; one annotation looks like this:

"annotations": [{"segmentation": [[510.66,423.01,511.72,420.03,510.45,416.0,510.34,413.02,510.77,410.26,510.77,407.5,510.34,405.16,511.51,402.83,511.41,400.49,510.24,398.16,509.39,397.31,504.61,399.22,502.17,399.64,500.89,401.66,500.47,402.08,499.09,401.87,495.79,401.98,490.59,401.77,488.79,401.77,485.39,398.58,483.9,397.31,481.56,396.35,478.48,395.93,476.68,396.03,475.4,396.77,473.92,398.79,473.28,399.96,473.49,401.87,474.56,403.47,473.07,405.59,473.39,407.71,476.68,409.41,479.23,409.73,481.56,410.69,480.4,411.85,481.35,414.93,479.86,418.65,477.32,420.03,476.04,422.58,479.02,422.58,480.29,423.01,483.79,419.93,486.66,416.21,490.06,415.57,492.18,416.85,491.65,420.24,492.82,422.9,493.56,424.39,496.43,424.6,498.02,423.01,498.13,421.31,497.07,420.03,497.07,415.15,496.33,414.51,501.1,411.96,502.06,411.32,503.02,415.04,503.33,418.12,501.1,420.24,498.98,421.63,500.47,424.39,505.03,423.32,506.2,421.31,507.69,419.5,506.31,423.32,510.03,423.01,510.45,423.01]],
"area": 702.1057499999998,
"iscrowd": 0,
"image_id": 289343,
"bbox": [473.07,395.93,38.65,28.67],
"category_id": 18,
"id": 1768}]

image_id refers to the id of the corresponding entry in images; one image may have several annotations, since each annotation labels only a single object. category_id identifies the class. segmentation is the segmentation label, area is the size of the segmented region, iscrowd marks whether the annotation covers a crowd of objects labeled as a group, and bbox is the object's bounding box, stored as [x, y, w, h].
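COCO's [x, y, w, h] boxes have to be converted to corner coordinates before training; a minimal sketch of the conversion done in the generator's _data_generation method, applied to the example annotation above (the helper name coco_bbox_to_corners is ours, for illustration only):

```python
# Convert a COCO bbox [x, y, w, h] (top-left corner, width, height)
# to corner coordinates [xmin, ymin, xmax, ymax]
def coco_bbox_to_corners(bbox):
    xmin, ymin, w, h = bbox
    xmin, ymin = int(xmin), int(ymin)
    return [xmin, ymin, int(xmin + w), int(ymin + h)]

# bbox of the example annotation shown above
print(coco_bbox_to_corners([473.07, 395.93, 38.65, 28.67]))  # [473, 395, 511, 423]
```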

The CoCoDataGenrator class code is as follows

import cv2
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io


class CoCoDataGenrator:
    def __init__(self,
                 coco_annotation_file,
                 img_shape=(640, 640, 3),
                 batch_size=1,
                 max_instances=100,
                 include_crowd=False,
                 include_mask=False,
                 include_keypoint=False):
        self.img_shape = img_shape
        self.batch_size = batch_size
        self.max_instances = max_instances
        self.include_crowd = include_crowd
        self.include_mask = include_mask
        self.include_keypoint = include_keypoint

        self.current_batch_index = 0
        self.total_batch_size = 0
        self.img_ids = []
        self.coco = COCO(annotation_file=coco_annotation_file)
        self.load_data()

    def load_data(self):
        # Initial filtering on whether the data contains crowd annotations
        target_img_ids = []
        for k in self.coco.imgToAnns:
            annos = self.coco.imgToAnns[k]
            print(annos)
            if annos:
                annos = list(filter(lambda x: x['iscrowd'] == self.include_crowd, annos))
                if annos:
                    target_img_ids.append(k)
        self.total_batch_size = len(target_img_ids) // self.batch_size
        self.img_ids = target_img_ids

    def next_batch(self):
        if self.current_batch_index >= self.total_batch_size:
            self.current_batch_index = 0
            self._on_epoch_end()

        batch_img_ids = self.img_ids[self.current_batch_index * self.batch_size:
                                     (self.current_batch_index + 1) * self.batch_size]
        batch_imgs = []
        batch_bboxes = []
        batch_labels = []
        batch_masks = []
        batch_keypoints = []
        for img_id in batch_img_ids:
            # {"img":, "bboxes":, "labels":, "masks":, "key_points":}
            data = self._data_generation(image_id=img_id)
            if len(np.shape(data['img'])) > 0:
                batch_imgs.append(data['img'])

                if len(data['labels']) > self.max_instances:
                    batch_bboxes.append(data['bboxes'][:self.max_instances, :])
                    batch_labels.append(data['labels'][:self.max_instances])
                else:
                    pad_num = self.max_instances - len(data['labels'])
                    batch_bboxes.append(np.pad(data['bboxes'], [(0,pad_num), (0, 0)]))
                    batch_labels.append(np.pad(data['labels'], [(0,pad_num)]))

                if self.include_mask:
                    batch_masks.append(data['masks'])

                if self.include_keypoint:
                    batch_keypoints.append(data['keypoints'])

        self.current_batch_index += 1

        if len(batch_imgs) < self.batch_size:
            return self.next_batch()

        output = {
            'imgs': np.array(batch_imgs, dtype=np.int32),
            'bboxes': np.array(batch_bboxes, dtype=np.int16),
            'labels': np.array(batch_labels, dtype=np.int8),
            'masks': np.array(batch_masks, dtype=np.int8),
            'keypoints': np.array(batch_keypoints, dtype=np.int16)
        }

        return output

    def _on_epoch_end(self):
        np.random.shuffle(self.img_ids)

    def _resize_im(self, origin_im, bboxes):
        """ Resize the image / mask / boxes

        :param origin_im:
        :param bboxes:
        :return im_blob: [h, w, 3]
                gt_boxes: [N, [xmin, ymin, xmax, ymax]]
        """
        im_shape = np.shape(origin_im)
        im_size_max = np.max(im_shape[0:2])
        im_scale = float(self.img_shape[0]) / float(im_size_max)

        # Resize the original image
        im_resize = cv2.resize(origin_im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
        im_resize_shape = np.shape(im_resize)
        im_blob = np.zeros(self.img_shape, dtype=np.float32)
        im_blob[0:im_resize_shape[0], 0:im_resize_shape[1], :] = im_resize

        # Resize the corresponding boxes
        bboxes_resize = np.array(bboxes * im_scale, dtype=np.int16)

        return im_blob, bboxes_resize

    def _resize_mask(self, origin_masks):
        """ Resize the mask data
        :param origin_masks:
        :return: masks_resize: [h, w, instance]
                 gt_boxes: [N, [xmin, ymin, xmax, ymax]]
        """
        mask_shape = np.shape(origin_masks)
        mask_size_max = np.max(mask_shape[0:2])
        im_scale = float(self.img_shape[0]) / float(mask_size_max)

        # resize mask/box
        gt_boxes = []
        masks_resize = []
        for m in origin_masks:
            m = np.array(m, dtype=np.float32)
            m_resize = cv2.resize(m, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
            m_resize = np.array(m_resize >= 0.5, dtype=np.int8)

            # Compute the bounding box from the mask
            h, w = np.shape(m_resize)
            rows, cols = np.where(m_resize)
            # [xmin, ymin, xmax, ymax]
            xmin = np.min(cols) if np.min(cols) >= 0 else 0
            ymin = np.min(rows) if np.min(rows) >= 0 else 0
            xmax = np.max(cols) if np.max(cols) <= w else w
            ymax = np.max(rows) if np.max(rows) <= h else h
            bdbox = [xmin, ymin, xmax, ymax]
            gt_boxes.append(bdbox)

            mask_blob = np.zeros((self.img_shape[0], self.img_shape[1], 1), dtype=np.float32)
            mask_blob[0:h, 0:w, 0] = m_resize
            masks_resize.append(mask_blob)

        # [instance_num, [ymin, xmin, ymax, xmax]]
        gt_boxes = np.array(gt_boxes, dtype=np.int16)
        # [h, w, instance_num]
        masks_resize = np.concatenate(masks_resize, axis=-1)

        return masks_resize, gt_boxes

    def _data_generation(self, image_id):
        """ Pull the COCO annotation data: bounding boxes, categories, masks
        :param image_id:
        :return:
        """

        anno_ids = self.coco.getAnnIds(imgIds=image_id, iscrowd=self.include_crowd)
        bboxes = []
        labels = []
        masks = []
        keypoints = []
        for i in anno_ids:
            # Bounding box, converted to top-left / bottom-right coordinates
            ann = self.coco.anns[i]
            bbox = ann['bbox']
            xmin, ymin, w, h = bbox
            xmin = int(xmin)
            ymin = int(ymin)
            xmax = int(xmin + w)
            ymax = int(ymin + h)
            bboxes.append([xmin, ymin, xmax, ymax])
            # Category ID
            label = ann['category_id']
            labels.append(label)
            # Instance segmentation
            if self.include_mask:
                mask = self.coco.annToMask(ann)
                masks.append(mask)
            if self.include_keypoint and ann.get('keypoints'):
                keypoint = ann['keypoints']
                # Reshape to [x, y, v], where v=0 means the point is absent, v=1 occluded, v=2 visible
                keypoint = np.reshape(keypoint, [-1, 3])
                keypoints.append(keypoint)

        # The output contains 5 items; unused ones remain empty
        outputs = {
            "img": [],
            "labels": [],
            "bboxes": [],
            "masks": [],
            "keypoints": []
        }

        # Final processing: masks
        if self.include_mask:
            # [N, h, w]
            masks, bboxes = self._resize_mask(origin_masks=masks)
            outputs['masks'] = masks
            outputs['bboxes'] = bboxes

        # Final processing: keypoints
        if self.include_keypoint:
            # int16, since pixel coordinates overflow int8
            keypoints = np.array(keypoints, dtype=np.int16)
            outputs['keypoints'] = keypoints

        img = io.imread(self.coco.imgs[image_id]['coco_url'])
        if len(np.shape(img)) < 2:
            return outputs
        elif len(np.shape(img)) == 2:
            img = np.expand_dims(img, axis=-1)
            img = np.pad(img, [(0,0), (0,0), (0,2)])

        labels = np.array(labels, dtype=np.int8)
        bboxes = np.array(bboxes, dtype=np.int16)
        img_resize, bboxes_resize = self._resize_im(origin_im=img, bboxes=bboxes)
        outputs['img'] = img_resize
        outputs['labels'] = labels
        outputs['bboxes'] = bboxes_resize

        return outputs
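The fixed-size batching in next_batch (each image's boxes and labels are truncated or zero-padded to max_instances rows, so the whole batch stacks into one array) can be seen in isolation; a small sketch with made-up values:

```python
import numpy as np

max_instances = 100
# Two ground-truth objects for one image (made-up values)
bboxes = np.array([[473, 395, 511, 423], [10, 20, 110, 220]], dtype=np.int16)
labels = np.array([18, 1], dtype=np.int8)

# Zero-pad up to max_instances rows, exactly as next_batch does
pad_num = max_instances - len(labels)
bboxes_padded = np.pad(bboxes, [(0, pad_num), (0, 0)])
labels_padded = np.pad(labels, [(0, pad_num)])

print(bboxes_padded.shape, labels_padded.shape)  # (100, 4) (100,)
```

Padded rows are all zeros, which is why a label of 0 must never be a real class id; COCO category ids start at 1, so this convention is safe here.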

Now let's prepare to train on the COCO dataset

# Get the dataset's classes
classes = train_data.coco.cats
log_dir = "./logs"
summary_writer = tf.summary.create_file_writer(log_dir)
epochs = 101
for epoch in range(epochs):
    if epoch % 20 == 0 and epoch != 0:
        model.save_weights(log_dir + '/yolov3-tf-{}.h5'.format(epoch))
    for batch in range(train_data.total_batch_size):
        with tf.GradientTape() as tape:
            data = train_data.next_batch()
            # 获取样本的图像,边框,标签数据
            gt_imgs = data['imgs'] / 255.
            gt_boxes = data['bboxes'] / image_shape[0]
            gt_classes = data['labels']
            print('gt_imgs', gt_imgs)
            print('gt_boxes', gt_boxes)
            print('gt_classes', gt_classes)
            # 构建YOLO训练所需要的目标值
            yolo_targets = transform_targets(
                gt_boxes=gt_boxes,
                gt_lables=gt_classes,
                anchors=anchors,
                anchor_masks=anchor_masks,
                im_size=image_shape[0]
            )
            yolo_preds = model(gt_imgs, training=True)

Output:

gt_imgs [[[[0.15294118 0.15294118 0.15294118]
   [0.09803922 0.09803922 0.09803922]
   [0.08235294 0.08235294 0.08235294]
   ...
   [0.         0.         0.        ]
   [0.         0.         0.        ]
   [0.         0.         0.        ]]]]
(normalized image pixel values; the full dump is truncated here)
gt_boxes [[[0.7390625 0.6171875 0.7984375 0.6609375]
  [0.31875   0.3671875 0.4125    0.64375  ]
  [0.        0.7796875 0.5296875 0.9453125]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.425     0.3125    0.6609375 0.7484375]
  [0.2828125 0.134375  0.325     0.2484375]
  [0.271875  0.        0.6796875 0.34375  ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.2015625 0.31875   0.809375  0.9015625]
  [0.0953125 0.08125   0.9140625 0.71875  ]
  [0.4703125 0.1140625 0.5140625 0.15625  ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.175     0.240625  0.7484375 0.9890625]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]

 [[0.3125    0.1390625 0.9375    0.53125  ]
  [0.146875  0.        0.521875  0.3296875]
  [0.        0.        0.        0.       ]
  ...
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]
  [0.        0.        0.        0.       ]]]
gt_classes [[18  1 15  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18 44 70  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18  4 47 47  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]
 [18  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0]]

Let's look at the transform_targets method. It computes the IoU between the ground-truth annotations in the COCO dataset and the 9 kinds of anchors, picks the best-matching anchor for each box, and outputs that box's coordinates together with a confidence of 1 and its class label. From this we can see that YOLO does not use the image annotations directly as labels; instead, the labels are organized on the grid cells of each of the three feature layers, with 3 anchor sizes per layer as the basic unit.
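The anchor-matching IoU used in transform_targets compares only widths and heights, as if the box and anchor shared the same corner. A small self-contained sketch of that matching (anchor values taken from the docstring below, normalized by a 416-pixel input; the example box is made up):

```python
import numpy as np

# the 9 YOLOv3 anchors (w, h) in pixels, normalized by the input size
anchors = np.array([(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
                    (59, 119), (116, 90), (156, 198), (373, 326)],
                   np.float32) / 416

def best_anchor(box_wh):
    """Return the index of the anchor with the highest width/height IoU."""
    box_wh = np.asarray(box_wh, np.float32)
    # intersection assuming box and anchor share a corner
    inter = np.minimum(box_wh[0], anchors[..., 0]) * np.minimum(box_wh[1], anchors[..., 1])
    box_area = box_wh[0] * box_wh[1]
    anchor_area = anchors[..., 0] * anchors[..., 1]
    iou = inter / (box_area + anchor_area - inter)
    return int(np.argmax(iou))

# a medium-sized box close to anchor (62, 45) matches index 4
print(best_anchor([60 / 416, 45 / 416]))  # -> 4
```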

def transform_targets(gt_boxes, gt_lables, anchors, anchor_masks, im_size):
    """ Compute the YOLO training targets
    :param gt_boxes: [batch, num_boxes, (x1, y1, x2, y2)]
    :param gt_lables: [batch, num_boxes]
    :param anchors: [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
                    (59, 119), (116, 90), (156, 198), (373, 326)] / im_size
    :param anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    :param im_size:
    :return:  ([N, grid, grid, anchors, [x1, y1, x2, y2, obj, class]], [], [])
    """
    y_outs = []
    # 1/32 of the input image size
    grid_size = im_size // 32

    # areas of the 9 anchors; the anchors are already normalized
    anchors = np.array(anchors, np.float32)
    anchor_area = anchors[..., 0] * anchors[..., 1]
    print('anchor_area', anchor_area)

    # width and height of each gt box, also normalized
    box_wh = gt_boxes[..., 2:4] - gt_boxes[..., 0:2]
    box_wh = np.tile(np.expand_dims(box_wh, axis=-2),
                     (1, 1, np.shape(anchors)[0], 1))
    box_area = box_wh[..., 0] * box_wh[..., 1]
    print('box_area', box_area)

    # compute the IoU between each gt box and each anchor
    intersection = np.minimum(box_wh[..., 0], anchors[..., 0]) * np.minimum(box_wh[..., 1], anchors[..., 1])
    iou = intersection / (box_area + anchor_area - intersection)
    print('iou', iou)
    print('ioushape', iou.shape)
    # index of the anchor with the highest IoU for each ground-truth box
    anchor_idx = np.array(np.argmax(iou, axis=-1), np.float32)
    anchor_idx = np.expand_dims(anchor_idx, axis=-1)
    gt_labels = np.expand_dims(gt_lables, axis=-1)

    # concatenate into the intermediate result
    y_train = np.concatenate([gt_boxes, gt_labels, anchor_idx], axis=-1)
    # compute the final target values once per feature layer (three in total)
    for anchor_idxs in anchor_masks:
        y_outs.append(transform_targets_for_output(y_train, grid_size, anchor_idxs))
        grid_size *= 2

    return tuple(y_outs)

def transform_targets_for_output(y_true, grid_size, anchor_idxs):
    """ Generate the target values for one YOLO output layer
    :param y_true: [N, boxes, (x1, y1, x2, y2, class, best_anchor)]
    :param grid_size:
    :param anchor_idxs: [,,]
    :return: y_true_out: [N, grid, grid, anchors, [x1, y1, x2, y2, obj, class]]
    """
    print('y_true', y_true)
    N, num_boxes, _ = np.shape(y_true)

    # y_true_out: [N, grid, grid, anchors, [x1, y1, x2, y2, obj, class]]
    y_true_out = np.zeros((N, grid_size, grid_size, np.shape(anchor_idxs)[0], 6), dtype=np.float32)

    anchor_idxs = np.array(anchor_idxs, np.int32)
    for i in np.arange(N):
        for j in np.arange(num_boxes):
            # skip padded entries
            if y_true[i][j][2] == 0:
                continue
            # check whether this box's best anchor (y_true[i][j][5], one of
            # the 9 anchors) belongs to this layer's anchor mask
            anchor_eq = anchor_idxs == y_true[i][j][5]

            if np.any(anchor_eq):
                box = y_true[i][j][0:4]
                # center of the box
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2
                anchor_idx = np.array(np.where(anchor_eq)[0], np.int32)
                # grid cell containing the center
                grid_xy = np.array(box_xy // (1 / grid_size), np.int32)

                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                y_true_out[i, grid_xy[1], grid_xy[0], anchor_idx[0], :] = \
                    [box[0], box[1], box[2], box[3], 1, y_true[i, j, 4]]

    return y_true_out
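The grid_xy computation above maps a normalized box center to integer cell coordinates. A minimal sketch of that mapping at the three YOLOv3 grid sizes (the box center value is made up):

```python
import numpy as np

def center_to_cell(box_xy, grid_size):
    """Map a normalized (x, y) box center to integer grid-cell coordinates."""
    return np.array(np.array(box_xy) // (1 / grid_size), np.int32)

# the same center lands in a different cell at each feature-map scale
center = [0.55, 0.31]
for grid_size in (13, 26, 52):
    print(grid_size, center_to_cell(center, grid_size))
```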

Output:

anchor_area [0.00083008 0.00545898 0.02046387 0.02223633 0.06881836 0.07457275
 0.2064917  0.21621095 0.5328125 ]
box_area [[[0.00259766 0.00259766 0.00259766 ... 0.00259766 0.00259766 0.00259766]
  [0.02592773 0.02592773 0.02592773 ... 0.02592773 0.02592773 0.02592773]
  [0.08772949 0.08772949 0.08772949 ... 0.08772949 0.08772949 0.08772949]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.102854   0.102854   0.102854   ... 0.102854   0.102854   0.102854  ]
  [0.00481201 0.00481201 0.00481201 ... 0.00481201 0.00481201 0.00481201]
  [0.14018555 0.14018555 0.14018555 ... 0.14018555 0.14018555 0.14018555]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.35424072 0.35424072 0.35424072 ... 0.35424072 0.35424072 0.35424072]
  [0.52195312 0.52195312 0.52195312 ... 0.52195312 0.52195312 0.52195312]
  [0.0018457  0.0018457  0.0018457  ... 0.0018457  0.0018457  0.0018457 ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.42918213 0.42918213 0.42918213 ... 0.42918213 0.42918213 0.42918213]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.24511719 0.24511719 0.24511719 ... 0.24511719 0.24511719 0.24511719]
  [0.12363281 0.12363281 0.12363281 ... 0.12363281 0.12363281 0.12363281]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]]
iou [[[0.31954888 0.47584972 0.12693868 ... 0.01257996 0.01201445 0.00487537]
  [0.03201507 0.21054614 0.66947811 ... 0.12556309 0.11991869 0.04866203]
  [0.00946179 0.06222519 0.1874598  ... 0.25776757 0.40575879 0.1646536 ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00807045 0.05307508 0.19896033 ... 0.49810238 0.33256859 0.19303978]
  [0.17250127 0.50089186 0.23514674 ... 0.02330366 0.0222561  0.00903134]
  [0.00592128 0.03894114 0.14597701 ... 0.56491694 0.62916833 0.26310485]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00234326 0.01541038 0.05776825 ... 0.58291346 0.56153389 0.66485065]
  [0.00159033 0.01045876 0.03920633 ... 0.39561347 0.4142344  0.8811481 ]
  [0.44973546 0.33810375 0.09019327 ... 0.00893839 0.00853659 0.00346408]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00193409 0.01271951 0.04768108 ... 0.48112834 0.42830977 0.69437937]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.00338645 0.02227092 0.08348605 ... 0.46233081 0.84243369 0.46004401]
  [0.00671406 0.04415482 0.16552132 ... 0.57129077 0.57181569 0.23203813]
  [0.         0.         0.         ... 0.         0.         0.        ]
  ...
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]
  [0.         0.         0.         ... 0.         0.         0.        ]]]
ioushape (5, 100, 9)
y_true [[[ 0.7390625  0.6171875  0.7984375  0.6609375 18.         1.       ]
  [ 0.31875    0.3671875  0.4125     0.64375    1.         2.       ]
  [ 0.         0.7796875  0.5296875  0.9453125 15.         5.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.425      0.3125     0.6609375  0.7484375 18.         4.       ]
  [ 0.2828125  0.134375   0.325      0.2484375 44.         1.       ]
  [ 0.271875   0.         0.6796875  0.34375   70.         7.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.2015625  0.31875    0.809375   0.9015625 18.         8.       ]
  [ 0.0953125  0.08125    0.9140625  0.71875    4.         8.       ]
  [ 0.4703125  0.1140625  0.5140625  0.15625   47.         0.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.175      0.240625   0.7484375  0.9890625 18.         8.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]

 [[ 0.3125     0.1390625  0.9375     0.53125   18.         7.       ]
  [ 0.146875   0.         0.521875   0.3296875  1.         5.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  ...
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]
  [ 0.         0.         0.         0.         0.         0.       ]]]

Next, we build the loss and apply gradient descent.

# compute the loss for each of the 3 output layers
total_xy_loss = total_wh_loss = total_obj_loss = total_class_loss = 0.
for i in range(3):
    # target boxes, objectness and classes
    true_box, true_obj, true_class = np.split(yolo_targets[i], (4, 5), axis=-1)
    # predicted boxes, objectness, classes and box offsets
    pred_box, pred_obj, pred_class, pred_box_xywh = yolo_preds[i]

    xy_loss, wh_loss, obj_loss, class_loss = loss(
        pred_box=pred_box,
        pred_box_xywh=pred_box_xywh,
        true_box=true_box,
        pred_obj=pred_obj,
        true_obj=true_obj,
        pred_class=pred_class,
        true_class=true_class,
        anchors=anchors[anchor_masks[i]],
        ignore_thresh=0.5
    )

    total_xy_loss += tf.reduce_mean(xy_loss)
    total_wh_loss += tf.reduce_mean(wh_loss)
    total_obj_loss += tf.reduce_mean(obj_loss)
    total_class_loss += tf.reduce_mean(class_loss)

total_loss = total_xy_loss + total_wh_loss + total_obj_loss + total_class_loss
grad = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grad, model.trainable_variables))

Here is the loss function:

def loss(pred_box, pred_box_xywh, true_box, pred_obj, true_obj, pred_class, true_class, anchors, ignore_thresh,
         balanced_rate=5):
    """
    :param pred_box: [batch_size, grid, grid, anchors, (x1, y1, x2, y2)]
    :param pred_box_xywh: [batch_size, grid, grid, anchors, (tx, ty, tw, th)]
    :param true_box: [batch_size, grid, grid, anchors, (x1, y1, x2, y2)]
    :param pred_obj: [batch_size, grid, grid, anchors, 1]
    :param true_obj: [batch_size, grid, grid, anchors, 1]
    :param pred_class: [batch_size, grid, grid, anchors, num_classes]
    :param true_class: [batch_size, grid, grid, anchors, 1]
    :param anchors: [[w1,h1],[w2,h2],[w3,h3]]
    :param ignore_thresh: IoU threshold separating positive and negative samples
    :param balanced_rate: ratio of negative to positive samples
    :return:
    """
    # predicted center offsets
    # [batch_size, grid, grid, anchors, 2]
    pred_xy = pred_box_xywh[..., 0:2]
    # predicted width/height offsets
    # [batch_size, grid, grid, anchors, 2]
    pred_wh = pred_box_xywh[..., 2:4]

    # center of the target box
    true_xy = (true_box[..., 0:2] + true_box[..., 2:4]) / 2
    # width and height of the target box
    true_wh = true_box[..., 2:4] - true_box[..., 0:2]

    # weighting factor that emphasizes small objects
    box_loss_scale = 2 - true_wh[..., 0] * true_wh[..., 1]

    # invert the prediction-box equations to obtain the regression targets
    grid_size = tf.shape(true_box)[1]
    grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))
    # [grid_size, grid_size, 1, 2]
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)
    # target center offset relative to its grid cell
    # [batch_size, grid, grid, anchors, 2]
    true_xy = true_xy * tf.cast(grid_size, tf.float32) - tf.cast(grid, tf.float32)
    # target width/height as a log ratio against the anchor
    # [batch_size, grid, grid, anchors, 2]
    true_wh = tf.math.log(true_wh / anchors)
    true_wh = tf.where(tf.math.is_inf(true_wh), tf.zeros_like(true_wh), true_wh)

    # object mask: [batch_size, grid, grid, anchors]
    obj_mask = tf.squeeze(true_obj, -1)
    # count positive samples and derive the negative-sample budget
    positive_num = tf.cast(tf.reduce_sum(obj_mask), tf.int32) + 1
    negative_num = balanced_rate * positive_num
    # ignore false positives whose IoU with any gt box exceeds the threshold
    # [batch_size, grid, grid, anchors, num_gt_box] => [batch_size, grid, grid, anchors]
    best_iou = tf.map_fn(
        lambda x: tf.reduce_max(broadcast_iou(x[0], tf.boolean_mask(
            x[1], tf.cast(x[2], tf.bool))), axis=-1),
        (pred_box, true_box, obj_mask),
        tf.float32)
    ignore_mask = tf.cast(best_iou < ignore_thresh, tf.float32)
    # balance positive and negative samples by subsampling the negatives
    ignore_num = tf.cast(tf.reduce_sum(ignore_mask), tf.int32)
    if ignore_num > negative_num:
        neg_inds = tf.random.shuffle(tf.where(ignore_mask))[:negative_num]
        neg_inds = tf.expand_dims(neg_inds, axis=1)
        ones = tf.ones(tf.shape(neg_inds)[0], tf.float32)
        ones = tf.expand_dims(ones, axis=1)
        # rebuild the mask with only the sampled negatives
        ignore_mask = tf.zeros_like(ignore_mask, tf.float32)
        ignore_mask = tf.tensor_scatter_nd_add(ignore_mask, neg_inds, ones)

    # box center loss: [batch_size, grid, grid, anchors]
    xy_loss = obj_mask * box_loss_scale * tf.reduce_sum(tf.square(true_xy - pred_xy), axis=-1)
    # box width/height loss: [batch_size, grid, grid, anchors]
    wh_loss = obj_mask * box_loss_scale * tf.reduce_sum(tf.square(true_wh - pred_wh), axis=-1)

    # focal-style weighting for the objectness loss
    conf_focal = tf.pow(obj_mask - tf.squeeze(pred_obj, -1), 2)
    # objectness loss: positives always contribute; negatives contribute
    # only when their best IoU is below the ignore threshold
    obj_loss = losses.binary_crossentropy(true_obj, pred_obj)
    obj_loss = conf_focal * (obj_mask * obj_loss + (1 - obj_mask) * ignore_mask * obj_loss)

    # classification loss
    class_loss = obj_mask * losses.sparse_categorical_crossentropy(true_class, pred_class)

    # sum over (batch, gridx, gridy, anchors) => (batch, 1)
    xy_loss = tf.reduce_sum(xy_loss, axis=(1, 2, 3))
    wh_loss = tf.reduce_sum(wh_loss, axis=(1, 2, 3))
    obj_loss = tf.reduce_sum(obj_loss, axis=(1, 2, 3))
    class_loss = tf.reduce_sum(class_loss, axis=(1, 2, 3))

    return xy_loss, wh_loss, obj_loss, class_loss
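The target transform inside loss can be checked with plain NumPy: the box center becomes an offset within its grid cell, and the width/height become log ratios against the matched anchor. A minimal sketch (the box and anchor values are made up; the anchor is normalized by the input size as in the docstring):

```python
import numpy as np

def box_targets(true_box, anchor, grid_size):
    """Compute (cell offset, log wh ratio) targets for one normalized box."""
    true_xy = (true_box[0:2] + true_box[2:4]) / 2
    true_wh = true_box[2:4] - true_box[0:2]
    cell = np.floor(true_xy * grid_size)    # grid cell containing the center
    xy_offset = true_xy * grid_size - cell  # offset within that cell, in [0, 1)
    wh_log = np.log(true_wh / anchor)       # log ratio against the anchor
    return xy_offset, wh_log

box = np.array([0.40, 0.30, 0.60, 0.50])       # (x1, y1, x2, y2), normalized
anchor = np.array([62, 45], np.float32) / 416  # matched anchor, normalized
xy_offset, wh_log = box_targets(box, anchor, grid_size=13)
print(xy_offset, wh_log)
```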

def broadcast_iou(box_1, box_2):
    """ Compute the pairwise IoU

    :param box_1:
    :param box_2:
    :return: [batch_size, grid, grid, anchors, num_gt_box]
    """
    # box_1: (..., (x1, y1, x2, y2))
    # box_2: (N, (x1, y1, x2, y2))

    # broadcast boxes
    box_1 = tf.expand_dims(box_1, -2)
    box_2 = tf.expand_dims(box_2, 0)
    # new_shape: (..., N, (x1, y1, x2, y2))
    new_shape = tf.broadcast_dynamic_shape(tf.shape(box_1), tf.shape(box_2))
    box_1 = tf.broadcast_to(box_1, new_shape)
    box_2 = tf.broadcast_to(box_2, new_shape)

    int_w = tf.maximum(tf.minimum(box_1[..., 2], box_2[..., 2]) -
                       tf.maximum(box_1[..., 0], box_2[..., 0]), 0)
    int_h = tf.maximum(tf.minimum(box_1[..., 3], box_2[..., 3]) -
                       tf.maximum(box_1[..., 1], box_2[..., 1]), 0)
    int_area = int_w * int_h
    box_1_area = (box_1[..., 2] - box_1[..., 0]) * \
                 (box_1[..., 3] - box_1[..., 1])
    box_2_area = (box_2[..., 2] - box_2[..., 0]) * \
                 (box_2[..., 3] - box_2[..., 1])
    return int_area / (box_1_area + box_2_area - int_area)
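broadcast_iou pairs every predicted box with every ground-truth box via broadcasting. The same idea in plain NumPy, demonstrated on hand-made boxes:

```python
import numpy as np

def pairwise_iou(boxes_a, boxes_b):
    """IoU between every box in boxes_a and every box in boxes_b.
    Boxes are (x1, y1, x2, y2); returns shape [len(boxes_a), len(boxes_b)]."""
    a = np.expand_dims(np.asarray(boxes_a, np.float32), 1)  # [A, 1, 4]
    b = np.expand_dims(np.asarray(boxes_b, np.float32), 0)  # [1, B, 4]
    int_w = np.maximum(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0)
    int_h = np.maximum(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0)
    int_area = int_w * int_h
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    return int_area / (area_a + area_b - int_area)

# identical boxes give IoU 1, disjoint boxes give 0, partial overlap in between
iou = pairwise_iou([[0, 0, 2, 2]], [[0, 0, 2, 2], [3, 3, 4, 4], [1, 1, 3, 3]])
print(iou)
```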

Finally, we log the training process:

# Scalar
with summary_writer.as_default():
    tf.summary.scalar('loss/xy_loss', total_xy_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/wh_loss', total_wh_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/obj_loss', total_obj_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/class_loss', total_class_loss,
                      step=epoch * train_data.total_batch_size + batch)
    tf.summary.scalar('loss/total_loss', total_loss,
                      step=epoch * train_data.total_batch_size + batch)

# images: only take the first image of each batch
# ground truth
gt_img = gt_imgs[0].copy() * 255
gt_boxes = gt_boxes[0] * image_shape[0]
gt_classes = gt_classes[0]
non_zero_ids = np.where(np.sum(gt_boxes, axis=-1))[0]
for i in non_zero_ids:
    label = gt_classes[i]
    class_name = classes[label]['name']
    xmin, ymin, xmax, ymax = gt_boxes[i]
    gt_img = draw_bounding_box(gt_img, class_name, label, int(xmin), int(ymin), int(xmax),
                               int(ymax))

# predictions (named pred_classes so the COCO category dict `classes` is not shadowed)
pred_img = gt_imgs[0].copy() * 255
boxes, scores, pred_classes, valid_detection_nums = yolo_nms(yolo_preds, 91)
for i in range(valid_detection_nums[0]):
    if scores[0][i] > 0.5:
        label = int(pred_classes[0][i].numpy())
        if classes.get(label):
            class_name = classes[label]['name']
            xmin, ymin, xmax, ymax = boxes[0][i] * image_shape[0]
            pred_img = draw_bounding_box(pred_img, class_name, scores[0][i], int(xmin), int(ymin),
                                         int(xmax), int(ymax))

concat_imgs = tf.concat([gt_img[:, :, ::-1], pred_img[:, :, ::-1]], axis=1)
summ_imgs = tf.expand_dims(concat_imgs, 0)
summ_imgs = tf.cast(summ_imgs, dtype=tf.uint8)
with summary_writer.as_default():
    tf.summary.image("imgs/gt,pred,epoch{}".format(epoch), summ_imgs,
                     step=epoch * train_data.total_batch_size + batch)