Point Cloud Deep Learning Series 5: PointNet++ Paper and Code Analysis

Hello everyone.

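PointNet++ is an upgraded version of PointNet that adds the ability to capture local structure. This brings quite a few changes to the code. Taking classification as an example, we analyze both the structure and the code below.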

Network Structure


First, the network structure. To review the architecture of its predecessor, PointNet, see the earlier post on PointNet.

The improved version removes the T-Net. It has more layers, but they are organized in a more orderly way.


import tensorflow as tf  # TF 1.x API (get_shape().value, variable scopes)
import tf_util  # helper layers from the PointNet++ repo's utils/
from pointnet_util import pointnet_sa_module

def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}
    l0_xyz = point_cloud
    l0_points = None  # no extra per-point features at the input, XYZ only
    end_points['l0_xyz'] = l0_xyz

    # Set abstraction layers
    # Note: When using NCHW for layer 2, we see increased GPU memory usage (in TF1.4).
    # So we only use NCHW for layer 1 until this issue can be resolved.
    # layer1: 512 centroids, radius 0.2, 32 neighbors each -> l1_points (B, 512, 128)
    l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=512, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1', use_nchw=True)
    # layer2: 128 centroids, radius 0.4, 64 neighbors each -> l2_points (B, 128, 256)
    l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=128, radius=0.4, nsample=64, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
    # layer3: group_all=True treats the whole cloud as one region -> global feature (B, 1, 1024)
    l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=None, radius=None, nsample=None, mlp=[256,512,1024], mlp2=None, group_all=True, is_training=is_training, bn_decay=bn_decay, scope='layer3')

    # Fully connected layers (classification head)
    net = tf.reshape(l3_points, [batch_size, -1])  # (B, 1024)
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')  # logits for the 40 ModelNet40 classes

    return net, end_points


As with PointNet, the code above can be read in two parts: feature extraction and the classification task.


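The feature extraction part is the "Set abstraction layers" block in the code. Notably, it does not use a T-Net and processes the point cloud directly. It consists of three pointnet_sa_module modules, each containing a 3-layer MLP and one pooling layer, so 9 MLP layers in total are used for feature extraction.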
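
To make the shape bookkeeping of the three set-abstraction levels concrete, here is a small pure-Python trace. It only copies the npoint and mlp settings from get_model above. The names trace_ssg, count_mlp_layers and SSG_CONFIG are hypothetical. The batch size of 32 and the input size of 1024 points are illustrative assumptions, not values taken from this post.

```python
# Hypothetical shape trace for the three SSG set-abstraction levels in get_model.
# Assumption: batch_size=32 and num_point=1024 are illustrative defaults only.

# (npoint, mlp, group_all) for layer1..layer3, copied from get_model
SSG_CONFIG = [
    (512, [64, 64, 128], False),
    (128, [128, 128, 256], False),
    (None, [256, 512, 1024], True),
]


def trace_ssg(batch_size=32, num_point=1024):
    """Return the (B, points, channels) shapes of l1_points, l2_points, l3_points."""
    shapes = []
    n_in = num_point
    for npoint, mlp, group_all in SSG_CONFIG:
        if not group_all and npoint > n_in:
            raise ValueError(f"cannot sample {npoint} centroids from {n_in} points")
        # group_all treats the whole cloud as one region, so a single "point" remains
        n_out = 1 if group_all else npoint
        # each level outputs the channel count of its last MLP layer
        shapes.append((batch_size, n_out, mlp[-1]))
        n_in = n_out
    return shapes


def count_mlp_layers():
    """Number of 1x1-conv (shared MLP) layers used for feature extraction."""
    return sum(len(mlp) for _, mlp, _ in SSG_CONFIG)


if __name__ == "__main__":
    for level, shape in enumerate(trace_ssg(), start=1):
        print(f"l{level}_points: {shape}")
    b, n, c = trace_ssg()[-1]
    print("input to fc1:", (b, n * c))  # what tf.reshape(l3_points, [batch_size, -1]) yields
    print("MLP layers:", count_mlp_layers())
```

Running it prints (32, 512, 128), (32, 128, 256) and (32, 1, 1024) for the three levels, followed by a (32, 1024) input to fc1 and 9 MLP layers. This matches the layer count given above.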

The code of the pointnet_sa_module module is as follows:

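# From pointnet_util.py in the PointNet++ repo; sample_and_group and sample_and_group_all are defined in the same file
import tensorflow as tf  # TF 1.x API
import tf_util

def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, mlp2, group_all, is_training, bn_decay, scope, bn=True, pooling='max', knn=False, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling, i.e. the number of centroids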
            radius: float32 -- search radius in local region
            nsample: int32 -- how many points in each local region
            mlp: list of int32 -- output size for MLP on each point
            mlp2: list of int32 -- output size for MLP on each region
            group_all: bool -- group all points into one PC if set true, OVERRIDE
                npoint, radius and nsample settings
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- indices for local regions
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        # Sample and Grouping
        if group_all:
            nsample = xyz.get_shape()[1].value
            new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
        else:
            new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, knn, use_xyz)
        # new_points: (batch_size, npoint, nsample, C)

        # Point Feature Embedding: a shared MLP implemented as 1x1 convolutions
        if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
        for i, num_out_channel in enumerate(mlp):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1],
                                        bn=bn, is_training=is_training,
                                        scope='conv%d'%(i), bn_decay=bn_decay,
                                        data_format=data_format)
        if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

        # Pooling in Local Regions (reduces the nsample axis to size 1)
        if pooling=='max':
            new_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
        elif pooling=='avg':
            new_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
        elif pooling=='weighted_avg':
            with tf.variable_scope('weighted_avg'):
                dists = tf.norm(grouped_xyz,axis=-1,ord=2,keep_dims=True)
                exp_dists = tf.exp(-dists * 5)
                weights = exp_dists/tf.reduce_sum(exp_dists,axis=2,keep_dims=True) # (batch_size, npoint, nsample, 1)
                new_points *= weights # (batch_size, npoint, nsample, mlp[-1])
                new_points = tf.reduce_sum(new_points, axis=2, keep_dims=True)
        elif pooling=='max_and_avg':
            max_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
            avg_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
            new_points = tf.concat([avg_points, max_points], axis=-1)

        # [Optional] Further Processing
        if mlp2 is not None:
            if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
            for i, num_out_channel in enumerate(mlp2):
                new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                            padding='VALID', stride=[1,1],
                                            bn=bn, is_training=is_training,
                                            scope='conv_post_%d'%(i), bn_decay=bn_decay,
                                            data_format=data_format)
            if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

        new_points = tf.squeeze(new_points, [2]) # (batch_size, npoint, mlp[-1] or mlp2[-1])
        return new_xyz, new_points, idx
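
The embedding and pooling steps in this module can be mimicked on a single local region in plain Python, which helps when reading the pooling branches. This is a sketch under simplifying assumptions. shared_layer and pool_region are hypothetical names, batch normalization is omitted, and features are plain lists rather than TF tensors. The weighted_avg branch reproduces the TF code's exp(-5 * d) weights normalized over the region, where d is each neighbor's distance from the centroid.

```python
import math


def shared_layer(features, weights, bias):
    """One shared fully connected layer + ReLU applied to every point.

    This is what a 1x1 conv2d computes on one (nsample, C_in) region: the same
    weights (shape [C_out][C_in]) are reused for every point. BN is omitted here.
    """
    return [
        [max(0.0, sum(w * x for w, x in zip(row, f)) + b) for row, b in zip(weights, bias)]
        for f in features
    ]


def pool_region(features, offsets, mode="max"):
    """Pool the nsample feature vectors of one region into a single vector.

    offsets are the (dx, dy, dz) positions of the neighbors relative to the
    centroid, playing the role of grouped_xyz in the TF code.
    """
    cols = list(zip(*features))  # one tuple per channel
    if mode == "max":
        return [max(col) for col in cols]
    if mode == "avg":
        return [sum(col) / len(col) for col in cols]
    if mode == "weighted_avg":
        # weights = exp(-5 * dist) / sum(exp(-5 * dist)); closer neighbors count more
        exp_d = [math.exp(-5.0 * math.sqrt(dx * dx + dy * dy + dz * dz)) for dx, dy, dz in offsets]
        total = sum(exp_d)
        return [sum(e / total * v for e, v in zip(exp_d, col)) for col in cols]
    if mode == "max_and_avg":
        # same order as tf.concat([avg_points, max_points], axis=-1)
        return pool_region(features, offsets, "avg") + pool_region(features, offsets, "max")
    raise ValueError(f"unknown pooling mode: {mode}")


if __name__ == "__main__":
    region = [[1.0, 4.0], [3.0, 2.0]]             # 2 neighbors, 2 channels
    offsets = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]  # second neighbor is 1 unit away
    hidden = shared_layer(region, weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0])
    for mode in ("max", "avg", "weighted_avg", "max_and_avg"):
        print(mode, pool_region(hidden, offsets, mode))
```

Max pooling keeps the strongest response per channel regardless of neighbor order. This is why the module is insensitive to the order in which ball query returns neighbors.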

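Each module first samples centroids (farthest point sampling) and finds their local neighborhoods (points within radius, or the k nearest neighbors when knn=True). It then extracts features with a per-point fully connected network built from three 1×1 convolutions. Finally, it pools over each neighborhood to produce the output.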
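
The sampling and neighborhood search can be sketched in plain Python as below. In the official repo these steps run as custom TF ops; this version is only an illustration. farthest_point_sample, ball_query and group_relative are hypothetical names. Starting FPS from index 0 is an assumption. Padding short neighborhoods by repeating the first neighbor found is assumed to mirror the repo's query_ball_point op; this post does not state it.

```python
import math


def farthest_point_sample(points, npoint):
    """Greedy farthest point sampling (FPS): return indices of npoint centroids.

    Assumption: sampling starts from index 0.
    """
    n = len(points)
    if not 0 < npoint <= n:
        raise ValueError("npoint must be between 1 and the number of points")
    selected = [0]
    min_dist = [math.inf] * n  # squared distance to the nearest selected centroid
    last = 0
    for _ in range(npoint - 1):
        lx, ly, lz = points[last]
        for i, (x, y, z) in enumerate(points):
            d = (x - lx) ** 2 + (y - ly) ** 2 + (z - lz) ** 2
            min_dist[i] = min(min_dist[i], d)
        # next centroid: the point farthest from everything selected so far
        last = max(range(n), key=min_dist.__getitem__)
        selected.append(last)
    return selected


def ball_query(points, centroids, radius, nsample):
    """For each centroid, return exactly nsample indices of points within radius.

    Points are scanned in index order; short groups are padded by repeating the
    first neighbor found (assumed behavior, see the lead-in).
    """
    r2 = radius * radius
    groups = []
    for cx, cy, cz in centroids:
        idx = [i for i, (x, y, z) in enumerate(points)
               if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 < r2][:nsample]
        if not idx:
            raise ValueError("centroid has no neighbor within radius")
        idx += [idx[0]] * (nsample - len(idx))
        groups.append(idx)
    return groups


def group_relative(points, npoint, radius, nsample):
    """Sample centroids, group neighbors, and express them relative to each centroid."""
    centroids = [points[i] for i in farthest_point_sample(points, npoint)]
    groups = ball_query(points, centroids, radius, nsample)
    local = [[tuple(p - c for p, c in zip(points[j], centroid)) for j in group]
             for centroid, group in zip(centroids, groups)]
    return centroids, groups, local


if __name__ == "__main__":
    cloud = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
    print(farthest_point_sample(cloud, 3))  # [0, 3, 2]
    centroids, groups, local = group_relative(cloud, npoint=2, radius=1.5, nsample=3)
    print(groups)  # [[0, 1, 0], [3, 3, 3]]
    print(local[0])
```

Farthest point sampling spreads the centroids over the cloud; in the example, the isolated point at x=10 is picked second. The fixed nsample keeps every group the same size, so the groups can be stacked into the (batch_size, npoint, nsample, C) tensor that the 1×1 convolutions consume.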

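The classification part (fully connected layers with dropout) is not much different from PointNet's, so I will not go over it again.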


Summary


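The code above is from pointnet2_cls_ssg.py (single-scale grouping, SSG). Its multi-scale grouping (MSG) counterpart is pointnet2_cls_msg.py. Once the single-scale version is clear, the multi-scale version should not be hard to understand.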


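In addition, in my tests the accuracy of the SSG code stays around 90.2% and never reaches the 90.7% reported in the paper. I contacted the author by email. The author only resent the experimental settings, which are the same as the defaults, and did not reply further. So the cause of the gap remains unclear.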

