論文筆記:TrafficPredict: Trajectory Prediction for Heterogeneous Traffic-Agents
摘要
這是百度在AAAI2019發佈的一篇文章。這篇文章提出了一種基於4D-graph的方法實現複雜場景下的軌跡預測,研究對象包含行人、機動車和自行車。
實現方法
本文提出了一個基於LSTM的算法,名爲 TrafficPredict 。構建了一個4D Graph,輸入是軌跡序列數據,4D graph的兩個維度是個體間的交互,一個維度是時間序列,另一個維度是分類。graph中每個個體都是一個節點,每個類別(總共三個類別)表示爲一個超節點,節點間的關係用邊表示(邊包括同時刻個體間、同時刻類別間(即超節點的連接)、同時刻個體與對應類別、相鄰時刻同一個體與自身、相鄰時刻同一類別與自身)。
因爲同類物體的移動速度、個體間交互方式比較接近,因此這種同時提取類別運動特徵和個體運動特徵的方法能取得一個更好的結果。
實例(個體)層 :Instance Layer
Instance Layer 用於捕獲交通中每個個體的移動模式。
- 同類instance共享相同的權值(這裏指的是temporal edges);
- 總共有3個類別,因此有3個不同的LSTM(同樣指的是temporal edges);
- 所有類別共享spatial edges的權值(即個體與個體間的相互作用);
類別層:Category Layer
用於學習相同類別個體的移動模式,從而更好地預測每個個體的軌跡。類別層包括四個部分:超節點(代表一個特定的類別)、個體與對應超節點之間的邊、相鄰時刻超節點與自己的temporal edges。第一步先將instance LSTM提取到的特徵作爲category layer的輸入,以便讓類別層提取出同一類別的個體的運動模式;之後將category layer的輸出反作用於instance layer,從而改善instance layer的預測結果。
預測評估
假設交通參與者(行人、自行車、汽車)的位置服從雙變量高斯分佈(bivariate Gaussian distribution)
網絡用於預測這些參數:
損失函數構建如下:
數據集
除了算法之外,文中還提到了百度發佈的數據集Apollo。數據集包含155分鐘10FPS的數據集,包含豐富的軌跡信息、3D檢測框、雷達點雲數據。