結合代碼理解Pointnet++網絡結構

前言

Pointnet提取的全局特徵能夠很好地完成分類任務,由於網絡將所有的點最大池化爲了一個全局特徵,因此局部點與點之間的聯繫並沒有被網絡學習到,導致網絡的輸出缺乏點雲的局部結構特徵,因此PointNet對於場景的分割效果十分一般。在點雲分類和物體的Part Segmentation中,這樣的問題可以通過中心化物體的座標軸部分地解決,但在場景分割中很難去解決。
原文地址:https://arxiv.org/abs/1706.02413

因此作者在此基礎上又提出了能夠實現點雲作多層特徵提取的Pointnet++網絡,網絡結構如下:
在這裏插入圖片描述
圖片來源:https://arxiv.org/abs/1706.02413

網絡的基本組成

下面介紹上圖中的網絡設計,傳統的CNN在進行特徵學習時,使用卷積核作爲局部感受野,每層的卷積核共享權值,進過多層的特徵學習,最後的輸出會包含圖像的局部特徵信息。通過改變中借鑑CNN的採樣思路,採取分層特徵學習,即在小區域中使用點雲採樣+成組+提取局部特徵(S+G+P)的方式,包含這三部分的機構稱爲Set Abstraction

  • Sampling:隨機選擇一個初始點,然後依次利用FPS(最遠點採樣)進行採樣,直到達到目標點數;
  • Grouping:以採樣點爲中心,利用Ball Query劃一個R爲半徑的球,將裏面包含的點雲作爲一簇成組;
  • Pointnet: 對Sampling+Grouping以後的點雲進行局部的全局特徵提取。

以2D點圖爲例,整個SA(Set Abstraction)三步的實現過程表示如下:
在這裏插入圖片描述
在這裏插入圖片描述圖片來源:https://arxiv.org/abs/1706.02413

每層新的中心點都是從上一層抽取的特徵子集,中心點的個數就是成組的點集數,隨着層數增加,中心點的個數也會逐漸降低,抽取到點雲的局部結構特徵。

針對非均勻點雲情況

當點雲不均勻時,每個子區域中如果在分區的時候使用相同的球半徑,會導致部分稀疏區域採樣點過小。

文中提出**多尺度成組 (MSG)多分辨率成組 (MRG)**兩種解決辦法。
在這裏插入圖片描述

簡單概括這兩種採樣方法:

  • **多尺度成組(MSG):**對於選取的一箇中心點設置多個半徑進行成組,並將經過PointNet對每個區域抽取後的特徵進行拼接(concat)來當做該中心點的特徵,個人認爲這種做法會產生很多特徵重疊,結果會可以保留和突出(邊際疊加)更多局部關鍵的特徵,但是這種方式不同範圍內計算的權值卻很難共享,計算量會變大很多。
  • **多分辨率成組(MRG):**對不同特徵層上(分辨率)提取的特徵再進行concat,以上圖右圖爲例,最後的concat包含左右兩個部分特徵,分別來自底層和高層的特徵抽取,對於low level點雲成組後經過一個pointnet和high level的進行concat,思想是特徵的抽取中的跳層連接。當局部點雲區域較稀疏時,上層提取到的特徵可靠性可能比底層更差,因此考慮對底層特徵提升權重。當然,點雲密度較高時能夠提取到的特徵也會更多。這種方法優化了直接在稀疏點雲上進行特徵抽取產生的問題,且相對於MSG的效率也較高。

在該網絡中作者使用了對輸入點雲進行隨機採樣(丟棄)random input dropout(DP)方法。Dropout的設計本身是爲了降低過擬合,增強模型的魯棒性,結果顯示對於分類任務的效果也有不錯的提升,作者給了一個對比圖:
在這裏插入圖片描述
本文中使用的縮寫說明:

  • SA:set abstraction 點集抽取模塊
  • FC:fully connected layers 全連接層
  • FP:feature
    propagation 特徵傳播模塊(跨層連接,多個全連接)

SA模塊的代碼實現

  • utils/pointnet_util.py/ 中採樣成組的代碼具體實現。
def sample_and_group(npoint, radius, nsample, xyz, points, knn=False, use_xyz=True):
    '''
    輸入參數說明:
    Input:
        npoint: int32,中心點的數量(分組數)
        radius: float32,ball quary的球半徑大小
        nsample: int32,區域內採樣到的點數
        xyz: (batch_size, ndataset, 3) TF tensor,例如:分類任務起始值(32,1024,3)
        points: (batch_size, ndataset, channel) TF tensor, 如果爲None則等於xyz
        knn: bool, True表示使用KNN方法採樣,否則使用球半徑搜索
        use_xyz: bool, True 表示抽取的局部點的特徵與xyz進行concat, 否則不進行,默認爲True
        
    輸出參數說明:
    Output:
        new_xyz: (batch_size, npoint, 3) TF tensor
        new_points: (batch_size, npoint, nsample, 3+channel) TF tensor,點的特徵進行了concat
        idx: (batch_size, npoint, nsample) TF tensor, 採樣的局部區域內點的索引值
        grouped_xyz: (batch_size, npoint, nsample, 3) TF tensor, 通過減去xyz對點進行區域歸一化
        注:源碼中沒有tf_ops/grouping和sampling/下沒有放編譯生成對應的鏈接庫.so文件,可能要重新編譯才能執行相應的py腳本
    '''
	#1.對原始點雲輸入進行採樣和分組
    new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz)) # (batch_size, npoint, 3)
    if knn:
        _,idx = knn_point(nsample, xyz, new_xyz)
    else:
        idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = group_point(xyz, idx) # (batch_size, npoint, nsample, 3)
    grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # translation normalization,減去中心點座標進行區域座標歸一化
    
    #2.對高層次特徵進行分組
    if points is not None:
        grouped_points = group_point(points, idx) # (batch_size, npoint, nsample, channel)
        if use_xyz:
            new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nample, 3+channel)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz

    return new_xyz, new_points, idx, grouped_xyz

#在最後一次SA操作中,需要對全部特徵進行採樣分組
def sample_and_group_all(xyz, points, use_xyz=True):
    '''
    #輸出變爲三個參數,功能同上
    Inputs:
        xyz: (batch_size, ndataset, 3) TF tensor
        points: (batch_size, ndataset, channel) TF tensor
        use_xyz: bool
    輸出:
    Outputs:
        new_xyz: (batch_size, 1, 3) as (0,0,0)
        new_points: (batch_size, 1, ndataset, 3+channel) TF tensor
    Note:
       等價於sample_and_group(npoint=1, radius=inf)以(0,0,0)爲重心
    '''
    batch_size = xyz.get_shape()[0].value
    nsample = xyz.get_shape()[1].value
    new_xyz = tf.constant(np.tile(np.array([0,0,0]).reshape((1,1,3)), (batch_size,1,1)),dtype=tf.float32) # (batch_size, 1, 3)
    idx = tf.constant(np.tile(np.array(range(nsample)).reshape((1,1,nsample)), (batch_size,1,1)))
    grouped_xyz = tf.reshape(xyz, (batch_size, 1, nsample, 3)) # (batch_size, npoint=1, nsample, 3)
    if points is not None:
        if use_xyz:
            new_points = tf.concat([xyz, points], axis=2) # (batch_size, 16, 259)
        else:
            new_points = points
        new_points = tf.expand_dims(new_points, 1) # (batch_size, 1, 16, 259)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points, idx, grouped_xyz


def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, mlp2, group_all, is_training, bn_decay, scope, bn=True, pooling='max', knn=False, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- 最遠點採樣點數(中心點數/成組數)
            radius: float32 -- 局部區域的搜索半徑
            nsample: int32 -- 每個區域內的採樣點數
            mlp: list of int32 -- 對每個點進行MLP的網絡(輸出)大小
            mlp2: list of int32 -- 對每個區域進行MLP的網絡(輸出)大小
            group_all: bool -- 如果爲True,則重寫npoint, radius and nsample爲None
            use_xyz: bool, True 表示抽取的局部點的特徵與xyz進行concat, 否則不進行
            use_nchw: bool, True, 使用NCHW點雲數據格式進行卷積, 作者指出這樣比NHWC格式的計算更快
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- 區域索引
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        # Sample and Grouping
        if group_all:
            nsample = xyz.get_shape()[1].value
            new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
        else:
            new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, knn, use_xyz)

        # Point Feature Embedding
        if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])#nchw->nwch
        for i, num_out_channel in enumerate(mlp):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1],
                                        bn=bn, is_training=is_training,
                                        scope='conv%d'%(i), bn_decay=bn_decay,
                                        data_format=data_format) 
        if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])#nchw->nhwc
    """
    省略 some code(區域max pooling)
    """
    
    
#針對稀疏點雲加入多尺度採樣(msg)
def pointnet_sa_module_msg(xyz, points, npoint, radius_list, nsample_list, mlp_list, is_training, bn_decay, scope, bn=True, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) module with Multi-Scale Grouping (MSG)
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius: list of float32 -- search radius in local region
            nsample: list of int32 -- how many points in each local region
            mlp: list of list of int32 -- output size for MLP on each point
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, \sum_k{mlp[k][-1]}) TF tensor
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz))
        new_points_list = []
        for i in range(len(radius_list)):
            radius = radius_list[i]
            nsample = nsample_list[i]
            idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
            grouped_xyz = group_point(xyz, idx)
            grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1])
            if points is not None:
                grouped_points = group_point(points, idx)
                if use_xyz:
                    grouped_points = tf.concat([grouped_points, grouped_xyz], axis=-1)
            else:
                grouped_points = grouped_xyz
            if use_nchw: grouped_points = tf.transpose(grouped_points, [0,3,1,2])
            for j,num_out_channel in enumerate(mlp_list[i]):
                grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1,1],
                                                padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                                scope='conv%d_%d'%(i,j), bn_decay=bn_decay)
            if use_nchw: grouped_points = tf.transpose(grouped_points, [0,2,3,1])
            new_points = tf.reduce_max(grouped_points, axis=[2])
            new_points_list.append(new_points)
        new_points_concat = tf.concat(new_points_list, axis=-1)
        return new_xyz, new_points_concat


def pointnet_fp_module(xyz1, xyz2, points1, points2, mlp, is_training, bn_decay, scope, bn=True):
    ''' PointNet Feature Propogation (FP) Module
    	FP層,作用是更新從插值操作和跳層連接合並來的特徵
        Input:                                                                                                      
            xyz1: (batch_size, ndataset1, 3) TF tensor                                                              
            xyz2: (batch_size, ndataset2, 3) TF tensor, sparser than xyz1                                           
            points1: (batch_size, ndataset1, nchannel1) TF tensor                                                   
            points2: (batch_size, ndataset2, nchannel2) TF tensor
            mlp: list of int32 --對給個點進行mlp後的輸出特徵維度大小                                                 
        Return:
            new_points: (batch_size, ndataset1, mlp[-1]) TF tensor
            注:這一部分會用到插值模塊,源碼中帶有tf_ops/3d_interpolation/tf_interpolate_so.so文件可以使用,不用重新編譯。不同於需要進行編譯的grouping和sampling操作。
    '''
    with tf.variable_scope(scope) as sc:
        dist, idx = three_nn(xyz1, xyz2)
        dist = tf.maximum(dist, 1e-10)
        norm = tf.reduce_sum((1.0/dist),axis=2,keep_dims=True)
        norm = tf.tile(norm,[1,1,3])
        weight = (1.0/dist) / norm
        interpolated_points = three_interpolate(points2, idx, weight)

        if points1 is not None:
            new_points1 = tf.concat(axis=2, values=[interpolated_points, points1]) # B,ndataset1,nchannel1+nchannel2
        else:
            new_points1 = interpolated_points
        new_points1 = tf.expand_dims(new_points1, 2)
        for i, num_out_channel in enumerate(mlp):
            new_points1 = tf_util.conv2d(new_points1, num_out_channel, [1,1],
                                         padding='VALID', stride=[1,1],
                                         bn=bn, is_training=is_training,
                                         scope='conv_%d'%(i), bn_decay=bn_decay)
        new_points1 = tf.squeeze(new_points1, [2]) # B,ndataset1,mlp[-1]
        return new_points1

以上是SA和FP部分的代碼實現,接下來對分類任務的代碼進行解讀。

單尺度成組(SSG)分類網絡的實現

以最基礎的單尺度採樣分組設計爲例,結合代碼瞭解模型的搭建過程。

  • models/pointnet2_cls_ssg.py /
def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}
    l0_xyz = point_cloud
    l0_points = None
    end_points['l0_xyz'] = l0_xyz
    # Set abstraction layers
    # Note: When using NCHW for layer 2, we see increased GPU memory usage (in TF1.4).
    # So we only use NCHW for layer 1 until this issue can be resolved.
    """
    調用三次SA模塊+三次全連接層+兩次dropout=0.5,和PointNet一樣,除最後一層外,在所有的全連接層後都會進行批量歸一化操作+ReLU操作:
    SA(512, 0.2, [64, 64, 128]) → SA(128, 0.4, [128, 128, 256]) → SA([256, 512, 1024]) →
FC(512, 0.5) → FC(256, 0.5) → FC(K)
    """
    l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=512, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1', use_nchw=True)
    l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=128, radius=0.4, nsample=64, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
    l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=None, radius=None, nsample=None, mlp=[256,512,1024], mlp2=None, group_all=True, is_training=is_training, bn_decay=bn_decay, scope='layer3')
     # Fully connected layers
    net = tf.reshape(l3_points, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

    return net, end_points
    
    """
    對於多尺度的分類網絡模型(MSG)對應於pointnet2_cls_msg.py,這裏的半徑和mlp維度都分別轉變爲向量和數組表示形式,整體的計算過程如下:    
SA(512, [0.1, 0.2, 0.4], [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) →
SA(128, [0.2, 0.4, 0.8], [[64, 64, 128], [128, 128, 256], [128, 128, 256]]) →
SA([256, 512, 1024]) → F C(512, 0.5) → F C(256, 0.5) → F C(K)
    對於多分辨率分類模型(MRG),作者在附錄中只是給出了設計的步驟,實現源碼沒有給出
    """

文章給出了針對ModelNet40S數據集上的分割模型的效果比較:
在這裏插入圖片描述
相比於Pointnet的結果,Pointnet++在此有小幅度的提升。

對於分割部分,會單獨進行一次總結,文中給出的分割效果對比圖:
在這裏插入圖片描述
結果顯示在場景分割網絡中,準確度關係爲:MSG+DP > MRG+DP > SSG> PointNet

源碼其餘部分的介紹不詳細展開,根據個人理解將源碼的結構與功能設計展示如下:
在這裏插入圖片描述

結語

本文主要結合代碼層面總結了pointnet++網絡設計以及分類任務的實現。重點理解pointnet++是如何利用set abstraction(SA)這種結構學習到局部結構上的特徵,並通過跳步連接和多尺度採樣(MSG+DP)來提高模型對點雲的分割準確性。可以注意到pointnet++中在特徵提取時使用pointnet網絡,但是最後的結果的魯棒性在不添加其他設計的情況下沒有原網絡好,並且作者沒有繼續使用T-net進行點雲對齊的方法。
博客內容有很多理解不足之處,請多多交流指正,接下來將在此基礎上繼續進行相關論文的學習。

參考源碼地址:
1.原論文實現代碼
https://github.com/charlesq34/pointnet2
2.基於pytorch實現:
https://github.com/erikwijmans/Pointnet2_PyTorch
https://github.com/yanx27/Pointnet_Pointnet2_pytorch

推薦免費又認真的知識星球 — 點雲PCL,分享計算機視覺領域相關資料,點雲分析的熱門論文,個人學習計劃等,自己學不動了可以進去交流分享。
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章