論文理解 Linkage Based Face Clustering via Graph Convolution Network
背景
其實是利用GCN進行人臉圖像聚類
論文:Linkage Based Face Clustering via Graph Convolution Network
作者提供的代碼:GCN Clustering
要解決的問題
如何在無ID且未知多少類的情況下,對人臉進行聚類
論文本身提出了將聚類問題看作是節點連接的預測問題的觀點,即如果兩張人臉圖像屬於同一個ID,則這兩張人臉圖像之間就存在連接;這裏會用到圖卷積網絡,預習請點擊鏈接
基於GCN的人臉圖像聚類
簡單點說,作者通過構造了一個GCN網絡去進行預測,具體見下圖
以上,爲了實現這個GCN,你還需要準備以下東西:
- 人臉特徵提取模型
- KNN搜索方法
- 人臉識別用的數據庫
流程上來說,作者將聚類過程劃分爲了對多個子圖(SubGraphs)進行連接預測,然後再將預測結果鏈接在一起;以中心節點以及其連接(K-Hop)構成了作者論文中提到的Instance Pivot Subgraph(IPS),然後使用GCN預測其他節點是否應該與中心節點相連。
圖卷積層
在GCN中主要是配合拉普拉斯矩陣對圖結構進行處理,論文中定義的GCN層如下所示
其中,爲特徵矩陣,激活函數使用的是ReLU,,符號是Concate操作;是圖的鄰接矩陣,,對這裏的翻譯一下,就是對鄰接矩陣的每行進行了歸一化,目的是爲了使的尺度不變;
這裏作者對預習中所提到的使用拉普拉斯矩陣進行圖卷積進行了一些輕微(點都不輕微好麼!!!)的改動。對比Symmetric normalized Laplacian定義(爲了方便理解,以下統一了符號):
可以知道,通過Renormalization Trick令,,則有圖卷積層的定義:
那麼拉普拉斯矩陣在圖卷積層中幹了什麼呢?直白點說,就是將自節點和鄰接節點做了加權平均;那麼作者做了啥改動呢?將鄰接節點做了平均並和自節點連接起來。
節點合併
這部分其實作者並沒有在論問題提太多,不過分析作者提供的代碼,應該使用了幾種方式進行嘗試,包括直接使用固定閾值連接節點後再使用寬度優先搜素得到聚類結果,不過作者給出的代碼在合併階段也採用了一些技巧,比如使用可變閾值以及最大合併數來防止聚類結果中某一類出現過大的聚類(我在使用Chinese Whisper算法進行聚類時,若直接使用固定閾值,會出現將不同人聚類到一類,然後導致這類中的樣本數量佔比超過總樣本數的50%以上,可以認爲對算法來說,這一類樣本屬於Hard Sample)
KNN搜索
過分的是,作者沒有提是如何進行KNN搜索生成圖的鄰接節點的,代碼中說可以使用任意方法。嗯,當你看到巨大的數據集以及龜速的搜索速度的時候,我決定使用Facebook的FAISS庫來加速了
MxNet復現
由於部署需要,我用mxnet復現了作者的算法 ,不是我不喜歡pytorch,真的。具體請先參考我的Azure Research目錄下的GraphGCN ,我不是在給微軟打廣告,真的。
GCN Layer
以下是MxNet中實現的GraphConv部分,其中參數A爲鄰接矩陣,x爲特徵矩陣
class GraphConv(gluon.HybridBlock):
def __init__(self, in_channels, channels):
super(GraphConv, self).__init__()
with self.name_scope():
self.weight = self.params.get('weight',shape=(channels, 2*in_channels),
init='xavier', dtype='float32', allow_deferred_init=True)
self.bias = self.params.get('bias', shape=channels, init='zeros', allow_deferred_init=True)
self.relu = nn.Activation('relu')
pass
def hybrid_forward(self, F, x, A, weight, bias):
f = F.concat(x, F.batch_dot(A, x), dim=2)
y = F.FullyConnected(data=f, weight=weight, bias=bias,
num_hidden=self.weight.shape[0], flatten=False, no_bias=False) # BNDxDF=BNF
z = self.relu(y)
return z
其他
數據的預處理部分和後處理我是用自己的庫搭建的,具體使用到的模塊包括utils裏的fast_search,numpy等,後處理部分目前全部封裝到了face裏的clusterer中,具體實現是放在ChaosMX中的gcn.cpp文件中,暫時先這樣吧,懶了~謝謝觀看!