CVPR 2018 | CPN_COCO2017姿态估计冠军解决方案

CVPR 2018 | Cascaded Pyramid Network for Multi-Person Pose Estimation
https://github.com/chenyilun95/tf-cpn

1.文章概述

本文提出了一种级联金字塔网络CPN,该网络由全局金字塔网络(GlobalNet)和利用在线难例挖掘机制的精馏网络(RefineNet)组成。GlobalNet是一个特征金字塔网络,可以成功地定位“简单”的关键点(如眼睛和手),但可能无法准确识别被遮挡或看不见的关键点。RefineNet尝试通过整合来自GlobalNet的所有尺度的特征,以及在线难例关键点挖掘损失来处理“复杂”关键点的精确定位。

如下图所示,Cascaded Pyramid Network主要由两部分组成:GlobalNet和RefineNet。

2.GlobalNet

如下图所示,GlobalNet以ResNet为基础框架,使用与FPN相似的特征金字塔结构来估计关键点。每一个特征尺度多会输出对应的关键点信息。作者称这种结构为GlobalNet。

基于ResNet主干网的GlobalNet可以有效地定位眼睛等关键点,但可能无法准确定位髋部位置。像髋部这样的关键点定位通常需要更多的上下文信息和处理,而不是附近的外观特征。在许多情况下,单凭一个Global网络很难直接识别这些关键点。基于此作者在此后接了一个RefineNet。

3.RefineNet

如下图所示,在GlobalNet生成的特征金字塔表示的基础上,作者附加了一个细化网络来处理难例关键点。为了提高信息传输的效率和保持信息传输的完整性,RefineNet将不同的层次的特征进行上采样后concat。与堆叠沙漏的细分策略不同,RefineNet将所有金字塔特性串联起来,而不是简单地使用沙漏模块末尾的上采样特性。

随着网络训练的不断深入,网络对大多数简单关键点的关注越来越多,而对被遮挡和硬关键点的关注越来越少。我们应该确保这两类关键点之间的回归平衡。因此,在RefineNet训练中,根据训练损失来明确地在线选择难例关键点,并仅从所选关键点反向传播梯度,该方法被称为OHKM。如下代码所示为OHKM损失函数,从中可以看出该函数就是对MSE输出的结果进行了排序,并筛选其中难例部分进行重点回归。

class JointsOHKMMSELoss(nn.Module):
    def __init__(self, use_target_weight, topk=8):
        super(JointsOHKMMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk

    def ohkm(self, loss):
        ohkm_loss = 0.
        for i in range(loss.size()[0]):
            sub_loss = loss[i]
            topk_val, topk_idx = torch.topk(
                sub_loss, k=self.topk, dim=0, sorted=False
            )
            tmp_loss = torch.gather(sub_loss, 0, topk_idx)
            ohkm_loss += torch.sum(tmp_loss) / self.topk
        ohkm_loss /= loss.size()[0]
        return ohkm_loss

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = []
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                loss.append(0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                ))
            else:
                loss.append(
                    0.5 * self.criterion(heatmap_pred, heatmap_gt)
                )

        loss = [l.mean(dim=1).unsqueeze(dim=1) for l in loss]
        loss = torch.cat(loss, dim=1)

        return self.ohkm(loss)
4.结果展示

下图展示了不同阈值的NMS策略的性能,结果显示Soft-NMS表现出了最优性能。

下图结果显示了OHKM,在线难例挖掘的有效性。

最终的结果也显示了本文提出的策略的有效性,但总的来说本文提出的OHKM反而被其他SOTA算法广泛使用。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章