在之前的一篇文章《PointNet:3D點集分類與分割深度學習模型》中分析了PointNet網絡是如何進行3D點雲數據分類與分割的。但是PointNet存在的一個缺點是無法獲得局部特徵,這使得它很難對複雜場景進行分析。在PointNet++中,作者通過兩個主要的方法進行了改進,使得網絡能更好的提取局部特徵。第一,利用空間距離(metric space distances),使用PointNet對點集局部區域進行特徵迭代提取,使其能夠學到局部尺度越來越大的特徵。第二,由於點集分佈很多時候是不均勻的,如果默認是均勻的,會使得網絡性能變差,所以作者提出了一種自適應密度的特徵提取方法。通過以上兩種方法,能夠更高效的學習特徵,也更有魯棒性。
目錄
2.5 Point Feature Propagation for Set Segmentation
1.PointNet不足之處
在卷積神經網絡中,3D CNN和2D CNN很像,也可以通過多級學習不斷進行提取,同時也具有着卷積的平移不變性。
而在PointNet中 網絡對每一個點做低維到高維的映射進行特徵學習,然後把所有點映射到高維的特徵通過最大池化最終表示全局特徵。從本質上來說,要麼對一個點做操作,要麼對所有點做操作,實際上沒有局部的概念(loal context) 。同時也缺少local context 在平移不變性上也有侷限性。(世界座標系和局部座標系)。對點雲數據做平移操作後,所有的數據都將發生變化,導致所有的特徵,全局特徵都不一樣了。對於單個的物體還好,可以將其平移到座標系的中心,把他的大小歸一化到一個球中,但是在一個場景中有多個物體時則不好辦,需要對哪個物體做歸一化呢?
在PointNet++中,作者利用所在空間的距離度量將點集劃分(partition)爲有重疊的局部區域。在此基礎上,首先在小範圍中從幾何結構中提取局部特徵(淺層特徵),然後擴大範圍,在這些局部特徵的基礎上提取更高層次的特徵,直到提取到整個點集的全局特徵。可以發現,這個過程和CNN網絡的特徵提取過程類似,首先提取低級別的特徵,隨着感受野的增大,提取的特徵level越來越高。
PointNet++需要解決兩個關鍵的問題:第一,如何將點集劃分爲不同的區域;第二,如何利用特徵提取器獲取不同區域的局部特徵。這兩個問題實際上是相關的,要想通過特徵提取器來對不同的區域進行特徵提取,需要每個分區具有相同的結構。這裏同樣可以類比CNN來理解,在CNN中,卷積塊作爲基本的特徵提取器,對應的區域都是(n, n)的像素區域。而在3D點集當中,同樣需要找到結構相同的子區域,和對應的區域特徵提取器。
在本文中,作者使用了PointNet作爲特徵提取器,另外一個問題就是如何來劃分點集從而產生結構相同的區域。作者使用鄰域球來定義分區,每個區域可以通過中心座標和半徑來確定。中心座標的選取,作者使用了最遠點採樣算法算法來實現(farthest point sampling (FPS) algorithm)。
2. PointNet++網絡結構
PointNet++是PointNet的延伸,在PointNet的基礎上加入了多層次結構(hierarchical structure),使得網絡能夠在越來越大的區域上提供更高級別的特徵。
網絡的每一組set abstraction layers主要包括3個部分:Sampling layer, Grouping layer and PointNet layer。
· Sample layer:主要是對輸入點進行採樣,在這些點中選出若干個中心點;
· Grouping layer:是利用上一步得到的中心點將點集劃分成若干個區域;
· PointNet layer:是對上述得到的每個區域進行編碼,變成特徵向量。
每一組提取層的輸入是,其中N是輸入點的數量,d是座標維度,C是特徵維度。輸出是,其中N'是輸出點的數量,d是座標維度不變,C'是新的特徵維度。下面詳細介紹每一層的作用及實現過程。
2.1 Sample layer
使用farthest point sampling選擇N'個點,至於爲什麼選擇使用這種方法選擇點,文中提到相比於隨機採樣,這種方法能更好的的覆蓋整個點集。具體選擇多少箇中心點,數量怎麼確定,可以看做是超參數視數據規模來定。代碼爲:
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
2.2 Grouping layer
這一層使用Ball query方法生成N'個局部區域,根據論文中的意思,這裏有兩個變量 ,一個是每個區域中點的數量K,另一個是球的半徑。這裏半徑應該是佔主導的,會在某個半徑的球內找點,上限是K。球的半徑和每個區域中點的數量都是超參數。代碼爲:
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
2.3 PointNet layer
這一層是PointNet,接受N'×K×(d+C)的輸入。輸出是N'×(d+C)。需要注意的是,在輸入到網絡之前,會把該區域中的點變成圍繞中心點的相對座標。作者提到,這樣做能夠獲取點與點之間的關係。至此則完成了set abstraction工作,代碼爲:
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
2.4 點雲分佈不一致的處理方法
點雲分佈不一致時,每個子區域中如果在分區的時候使用相同的球半徑,會導致有些稀疏區域採樣點過小。
作者提到這個問題需要解決,並且提出了兩個方法:Multi-scale grouping (MSG) and Multi-resolution grouping (MRG)。下面是論文當中的示意圖。
下面分別介紹一下這兩種方法。
第一種多尺度分組(MSG),對於同一個中心點,如果使用3個不同尺度的話,就分別找圍繞每個中心點畫3個區域,每個區域的半徑及裏面的點的個數不同。對於同一個中心點來說,不同尺度的區域送入不同的PointNet進行特徵提取,之後concat,作爲這個中心點的特徵。也就是說MSG實際上相當於並聯了多個hierarchical structure,每個結構中心點數量一樣,但是區域範圍不同。PointNet的輸入和輸出尺寸也不同,然後幾個不同尺度的結構在PointNet有一個Concat。代碼是:
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)
return new_xyz, new_points_concat
另一種是多分辨率分組(MRG)。MSG很明顯會影響降低運算速度,所以提出了MRG,這種方法應該是對不同level的grouping做了一個concat,但是由於尺度不同,對於low level的先放入一個pointnet進行處理再和high level的進行concat。感覺和ResNet中的跳連接有點類似。
在這部分,作者還提到了一種random input dropout(DP)的方法,就是在輸入到點雲之前,對點集進行隨機的Dropout,比例使用了95%,也就是說進行95%的重新採樣。
2.5 Point Feature Propagation for Set Segmentation
對於點雲分割任務,我們還需要將點集上採樣回原始點集數量,這裏使用了分層的差值方法。代碼爲:
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
3. 參考資料
PointNet++官網鏈接:http://stanford.edu/~rqi/pointnet2/
PointNet++代碼:https://github.com/yanx27/Pointnet_Pointnet2_pytorch
PointNet++作者視頻講解文字版:https://www.cnblogs.com/yibeimingyue/p/12002469.html