PointNet++詳解與代碼

在之前的一篇文章《PointNet:3D點集分類與分割深度學習模型》中分析了PointNet網絡是如何進行3D點雲數據分類與分割的。但是PointNet存在的一個缺點是無法獲得局部特徵,這使得它很難對複雜場景進行分析。在PointNet++中,作者通過兩個主要的方法進行了改進,使得網絡能更好的提取局部特徵。第一,利用空間距離(metric space distances),使用PointNet對點集局部區域進行特徵迭代提取,使其能夠學到局部尺度越來越大的特徵。第二,由於點集分佈很多時候是不均勻的,如果默認是均勻的,會使得網絡性能變差,所以作者提出了一種自適應密度的特徵提取方法。通過以上兩種方法,能夠更高效的學習特徵,也更有魯棒性。

目錄

1.PointNet不足之處

2. PointNet++網絡結構

2.1 Sample layer

2.2 Grouping layer

2.3 PointNet layer

2.4 點雲分佈不一致的處理方法

2.5 Point Feature Propagation for Set Segmentation

3. 參考資料


1.PointNet不足之處

在卷積神經網絡中,3D CNN和2D CNN很像,也可以通過多級學習不斷進行提取,同時也具有着卷積的平移不變性。

而在PointNet中 網絡對每一個點做低維到高維的映射進行特徵學習,然後把所有點映射到高維的特徵通過最大池化最終表示全局特徵。從本質上來說,要麼對一個點做操作,要麼對所有點做操作,實際上沒有局部的概念(loal context)同時也缺少local context 在平移不變性上也有侷限性。(世界座標系和局部座標系)。對點雲數據做平移操作後,所有的數據都將發生變化,導致所有的特徵,全局特徵都不一樣了。對於單個的物體還好,可以將其平移到座標系的中心,把他的大小歸一化到一個球中,但是在一個場景中有多個物體時則不好辦,需要對哪個物體做歸一化呢?

在PointNet++中,作者利用所在空間的距離度量將點集劃分(partition)爲有重疊的局部區域。在此基礎上,首先在小範圍中從幾何結構中提取局部特徵(淺層特徵),然後擴大範圍,在這些局部特徵的基礎上提取更高層次的特徵,直到提取到整個點集的全局特徵。可以發現,這個過程和CNN網絡的特徵提取過程類似,首先提取低級別的特徵,隨着感受野的增大,提取的特徵level越來越高

PointNet++需要解決兩個關鍵的問題:第一,如何將點集劃分爲不同的區域;第二,如何利用特徵提取器獲取不同區域的局部特徵。這兩個問題實際上是相關的,要想通過特徵提取器來對不同的區域進行特徵提取,需要每個分區具有相同的結構。這裏同樣可以類比CNN來理解,在CNN中,卷積塊作爲基本的特徵提取器,對應的區域都是(n, n)的像素區域。而在3D點集當中,同樣需要找到結構相同的子區域,和對應的區域特徵提取器。

在本文中,作者使用了PointNet作爲特徵提取器,另外一個問題就是如何來劃分點集從而產生結構相同的區域。作者使用鄰域球來定義分區,每個區域可以通過中心座標和半徑來確定。中心座標的選取,作者使用了最遠點採樣算法算法來實現(farthest point sampling (FPS) algorithm)。


2. PointNet++網絡結構

PointNet++是PointNet的延伸,在PointNet的基礎上加入了多層次結構(hierarchical structure),使得網絡能夠在越來越大的區域上提供更高級別的特徵。

網絡的每一組set abstraction layers主要包括3個部分:Sampling layer, Grouping layer and PointNet layer。

· Sample layer:主要是對輸入點進行採樣,在這些點中選出若干個中心點;
· Grouping layer:是利用上一步得到的中心點將點集劃分成若干個區域;
· PointNet layer:是對上述得到的每個區域進行編碼,變成特徵向量。

每一組提取層的輸入是(N,(d+C)),其中N是輸入點的數量,d是座標維度,C是特徵維度。輸出是(N',(d+C^{'})),其中N'是輸出點的數量,d是座標維度不變,C'是新的特徵維度。下面詳細介紹每一層的作用及實現過程。


2.1 Sample layer

使用farthest point sampling選擇N'個點,至於爲什麼選擇使用這種方法選擇點,文中提到相比於隨機採樣,這種方法能更好的的覆蓋整個點集。具體選擇多少箇中心點,數量怎麼確定,可以看做是超參數視數據規模來定。代碼爲:

def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

2.2 Grouping layer

這一層使用Ball query方法生成N'個局部區域,根據論文中的意思,這裏有兩個變量 ,一個是每個區域中點的數量K,另一個是球的半徑。這裏半徑應該是佔主導的,會在某個半徑的球內找點,上限是K。球的半徑和每個區域中點的數量都是超參數。代碼爲:

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx

2.3 PointNet layer

這一層是PointNet,接受N'×K×(d+C)的輸入。輸出是N'×(d+C)。需要注意的是,在輸入到網絡之前,會把該區域中的點變成圍繞中心點的相對座標。作者提到,這樣做能夠獲取點與點之間的關係。至此則完成了set abstraction工作,代碼爲:

class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points =  F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

2.4 點雲分佈不一致的處理方法

點雲分佈不一致時,每個子區域中如果在分區的時候使用相同的球半徑,會導致有些稀疏區域採樣點過小

作者提到這個問題需要解決,並且提出了兩個方法:Multi-scale grouping (MSG) and Multi-resolution grouping (MRG)。下面是論文當中的示意圖。

下面分別介紹一下這兩種方法。

第一種多尺度分組(MSG),對於同一個中心點,如果使用3個不同尺度的話,就分別找圍繞每個中心點畫3個區域,每個區域的半徑及裏面的點的個數不同。對於同一個中心點來說,不同尺度的區域送入不同的PointNet進行特徵提取,之後concat,作爲這個中心點的特徵也就是說MSG實際上相當於並聯了多個hierarchical structure,每個結構中心點數量一樣,但是區域範圍不同。PointNet的輸入和輸出尺寸也不同,然後幾個不同尺度的結構在PointNet有一個Concat。代碼是:

class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat

另一種是多分辨率分組(MRG)。MSG很明顯會影響降低運算速度,所以提出了MRG,這種方法應該是對不同level的grouping做了一個concat,但是由於尺度不同,對於low level的先放入一個pointnet進行處理再和high level的進行concat。感覺和ResNet中的跳連接有點類似。

在這部分,作者還提到了一種random input dropout(DP)的方法,就是在輸入到點雲之前,對點集進行隨機的Dropout,比例使用了95%,也就是說進行95%的重新採樣。


2.5 Point Feature Propagation for Set Segmentation

對於點雲分割任務,我們還需要將點集上採樣回原始點集數量,這裏使用了分層的差值方法。代碼爲:

 def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points

3. 參考資料

PointNet++官網鏈接:http://stanford.edu/~rqi/pointnet2/

PointNet++代碼:https://github.com/yanx27/Pointnet_Pointnet2_pytorch

PointNet++作者視頻講解文字版:https://www.cnblogs.com/yibeimingyue/p/12002469.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章