CVPR2020:點雲分類的自動放大框架PointAugment

CVPR2020:點雲分類的自動放大框架PointAugment

PointAugment: An Auto-Augmentation Framework for Point Cloud Classification

論文地址:

https://openaccess.thecvf.com/content_CVPR_2020/html/Li_PointAugment_An_Auto-Augmentation_Framework_for_Point_Cloud_Classification_CVPR_2020_paper.html
code:https://github.com/liruihui/PointAugment
摘要

本文提出了一種新的自動放大框架PointAugment,在訓練分類網絡時,自動優化和放大點雲樣本,以豐富數據的多樣性。與現有的二維圖像自動放大方法不同,PointAugment是一種樣本感知的方法,採用一種對抗性學習策略來聯合優化放大器網絡和分類器網絡,使放大器能夠學習產生最適合分類器的放大樣本。此外,構造了一個帶形狀變換和點位移的可學習點放大函數,並根據分類器的學習進度,精心設計損失函數來採用放大樣本。大量實驗也證實了PointAugment的有效性和魯棒性,可以改善各種網絡的形狀分類和檢索性能。

1.介紹

近24年來,人對三維神經網絡的研究興趣與日俱增。可靠地訓練網絡通常依賴於數據的可用性和多樣性。然而,與ImageNet[10]和MS-COCO數據集[15]等二維圖像基準測試不同,3D數據集的數量通常要小得多,標籤數量相對較少,多樣性有限。例如,ModelNet40[38]是3D點雲分類最常用的基準之一,只有40個類別的12311個模型。有限的數據量和多樣性可能導致過擬合問題,進而影響網絡的泛化能力。目前,數據放大(DA)是一種非常普遍的策略,通過人工增加訓練樣本的數量和多樣性來避免過度擬合,提高網絡泛化能力。對於三維點雲,由於訓練樣本數量有限,且在3D中有巨大的放大空間,傳統的DA策略[23,24]通常只是在一個小的、固定的預先定義的放大範圍內隨機擾動輸入點雲,以保持類標籤。

儘管這種傳統的DA方法對現有的分類網絡有效,但可能導致訓練不足,如下所述。首先,現有的深部三維點雲處理方法將網絡訓練和數據採集視爲兩個獨立的階段,沒有聯合優化,例如反饋訓練結果以放大DA。因此,訓練後的網絡可能是次優的。其次,現有方法對所有輸入點雲樣本應用相同的固定放大過程,包括旋轉、縮放和/或抖動。在放大過程中忽略了樣本的形狀複雜度,例如,球體無論如何旋轉都保持不變,但複雜形狀可能需要更大的旋轉。因此,傳統的DA對於增加訓練樣本可能是多餘的或不充分的[6]。

爲了改進點雲樣本的放大,提出了一種新的三維點雲自動放大框架PointAugment,並展示了其放大形狀分類的有效性;見圖1。與以前的二維圖像不同,PointAugment學習生成特定於單個樣本的放大函數。此外,可學習的放大函數同時考慮了形狀變換和點方向位移,這與三維點雲樣本的特徵有關。此外,PointAugment通過一種對抗性學習策略,將放大網絡(augmentor)與分類網絡(Classifier)進行端到端的訓練,從而與網絡訓練共同優化放大過程。通過將分類損失作爲反饋,放大器可以學習通過擴大類內數據變化來豐富輸入樣本,而分類器可以學習通過提取不敏感特徵來克服這一問題。受益於這種對抗性學習,放大器可以學習生成在不同訓練階段最適合分類者的放大樣本,從而最大限度地提高分類者的能力。作爲探索3D點雲自動放大的第一次嘗試,展示了通過用PointAugment取代傳統的DA,可以在四個具有代表性的網絡上實現對ModelNet40[38](見圖1)和SHREC16[28](見第5節)數據集的形狀分類的明顯改進,包括PointNet[23]、PointNet++[24],RSCNN[18]和DGCNN[37]。此外,還展示了PointAugment在形狀檢索上的有效性,並評估了其魯棒性、損失配置和模塊化設計。
在這裏插入圖片描述
2.相關工作

圖像數據放大

訓練數據對深層神經網絡學習執行任務起着非常重要的作用。然而,與現實世界的複雜性相比,訓練數據的數量往往是有限的,因此經常需要數據放大來放大訓練集,最大限度地利用訓練數據學習到的知識。一些工作沒有隨機變換訓練數據樣本[42,41],而是嘗試利用圖像組合[12]、生成對抗網絡[31,27]、貝葉斯優化[35]和潛在空間中的圖像插值[4,16,2]從原始數據中生成放大樣本。然而,這些方法可能產生與原始數據不同的不可靠樣本。另一方面,一些圖像DA技術[12,42,41]對具有規則結構的圖像應用像素插值,但是不能處理順序不變的點雲。另一種方法的目的是找到一個預先定義的轉換函數的最佳組合,以增加訓練樣本,而不是基於人工設計或完全隨機性應用轉換函數。

AutoAugment[3]提出了一種強化學習策略,通過交替訓練代理任務和策略控制器,然後將學習到的放大函數應用於輸入數據,從而找到最佳的放大函數集。不久之後,另兩項研究,FastAugment[14]和PBA[8]探索了先進的超參數優化方法,以更有效地找到放大的最佳轉換。與這些學習爲所有訓練樣本找到固定的放大策略的方法不同,PointAugment是樣本感知的,這意味着在訓練過程中根據單個訓練樣本的屬性和網絡能力動態生成轉換函數。最近,Tang等人[33]張等[43]建議學習使用對抗策略的目標任務的放大策略。傾向於直接最大化增加樣本的損失,以提高圖像分類網絡的泛化能力。與之不同的是,PointAugment通過一個明確設計的邊界擴大了放大後的點雲與原始點雲之間的損失;動態地調整了放大樣本的難度,以便放大的樣本能夠更好地滿足不同訓練階段的分類要求。

點雲數據放大

在現有的點處理網絡中,數據放大主要包括圍繞重力軸的隨機旋轉、隨機縮放和隨機抖動[23,24]。這些手工制定的規則在整個訓練過程中都是固定的,因此可能無法獲得最佳樣本來有效地訓練網絡。到目前爲止,還沒有發現有任何研究利用三維點雲來實現網絡學習最大化的工作。

點雲深度學習

在PointNet架構[23]的基礎上,有幾篇文章[24,17,18]探索了局部結構來放大特徵學習。另一些人通過創建局部圖[36,37,29,45]或幾何元素[11,22]來探索圖形卷積網絡。另一個工作流程[32,34,19]將不規則點投影到規則空間中,以允許傳統的卷積神經網絡工作。與上述工作不同,目標不是設計一個新的網絡,而是通過有效地優化點雲樣本的增加來提高現有網絡的分類性能。爲此,設計了一個放大器來學習一個特殊的放大函數,並根據分類器的學習進度調整放大函數。

  1. Overview

這項工作的主要貢獻是PointAugment框架,該框架自動優化輸入點雲樣本的放大,以便更有效地訓練分類網絡。圖2說明了框架的設計,有兩個深層神經網絡組件:(i)一個放大器A和(ii)一個分類器C。

在闡述PointAugment框架之前,首先討論框架背後的關鍵思想。這些都是新的想法(在以前的作品[3,14,8]中沒有出現),使能夠有效地增加訓練樣本,這些樣本現在是三維點雲,而不是二維圖像。

•樣本感知。目標是通過考慮樣本的基本幾何結構,爲每個輸入樣本回歸一個特定的放大函數,而不是爲每個輸入數據樣本找到一套通用的放大策略或過程。稱之爲樣本感知的自動放大。

•2D與3D放大。與二維圖像放大不同,三維放大涉及更廣闊和不同的空間域。應該考慮雲的兩種變形點(包括點雲的變換和點雲的變換)的放大(包括點雲的變換和點雲的變換)。

•聯合優化。在網絡訓練過程中,分類器將逐漸學習並變得更加強大,因此需要更具挑戰性的放大樣本,以便更好地訓練分類器,因爲分類器變得更強。因此,以端到端的方式設計和訓練PointAugment框架,這樣就可以共同優化放大器和分類器。爲此,必須仔細設計損失函數,動態調整增加樣本的難度,同時考慮輸入樣本和分類器的容量。

  1. Method

在本節中,首先介紹放大器和分類器的網絡架構細節。然後,提出了爲放大器和分類器制定的損失函數,並介紹了端到端訓練策略。最後,給出了實現細節。

4.1. Network Architecture

放大器。不同於現有的工作[3,14,8],放大器是樣本感知的,學習生成一個特定的函數來放大每個輸入樣本。從現在起,爲了便於閱讀,去掉了下標i,並將P表示爲A的訓練樣本輸入,P′表示A的相應放大樣本輸出。放大器的總體架構如圖3(上圖)所示。
在這裏插入圖片描述
4.2. Augmentor loss

爲了使網絡學習最大化,由放大器生成的放大樣本P′應滿足兩個要求:(i)P′比P更具挑戰性,即目標是L(P′)≥L(P);(ii)P′不應失去形狀特異性,這意味着應該描述一個與P不太遠(或不同)的形狀。爲了達到要求(i),一個簡單的方法來描述放大器(表示爲LA)的損失函數是使P和P′上的交叉熵損失之差最大化,或者等效地最小化。
在這裏插入圖片描述
4.3. Classifier loss

分類C的目標是正確預測P和P′。另外,無論輸入P或P′,C都應該具有學習穩定的每形狀全局特徵的能力。

4.4. End-to-end training strategy

算法1總結了端到端訓練策略。總的來說,在訓練過程中,該程序交替地優化和更新放大器A和分類器C中的可學習參數,同時調整另一個參數。

4.5. Implementation details

使用PyTorch[21]實現PointAugment。具體來說,將訓練階段的數量設爲S=250,批量大小爲B=24。爲了訓練放大器,採用了學習率爲0.001的Adam優化器。爲了訓練分類人員,遵循發佈的代碼和文件中各自的原始配置。具體來說,對於PointNet[23]、PointNet++[24]和RSCNN[18],使用的Adam優化器的初始學習率爲0.001,該值逐漸降低,每20個時期衰減率爲0.5。

對於DGCNN[37],使用動量爲0.9、基本學習率爲0.1的SGD解算器,該解算器使用餘弦退火策略衰減[9]。還需要注意的是,爲了減少模型振盪[5],遵循[31]使用混合訓練樣本來訓練點放大,混合訓練樣本包含一半原始訓練樣本,另一半包含先前放大的樣本,而不是隻使用原始訓練樣本。詳見[31]。此外,爲了避免過度擬合,設置了0.5的脫落概率來隨機丟棄或保持迴歸的形狀方向變換和點方向位移。在測試階段,遵循之前的網絡[23,24],將輸入的測試樣本輸入到經過訓練的分類器,以獲得預測的標籤,而不需要任何額外的計算成本。
在這裏插入圖片描述
5. Experiments

在點放大上做了大量的實驗。首先,介紹了實驗中使用的基準數據集和分類器。然後,在形狀分類和形狀檢索方面評估PointAugment。接下來,將對PointAugment的健壯性、損耗配置和模塊化設計進行詳細分析。最後,提出進一步的討論和未來可能的擴展。
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章