深度學習分類、識別等任務常用的餘弦距離和對應的PyTorch代碼

餘弦距離常常在人臉識別,圖像分類,行人重識別中應用。自從centerNet可視化了softmax loss之後,人們得知神經網絡的輸出空間原來是呈現原點向外發散狀,分類結果是可以通過判斷兩個樣本在輸出空間對應的向量之間的夾角來得知是否是同一類樣本。這個夾角就是所謂的餘弦距離,夾角越小,兩個樣本越相似。

預備的數學知識

\vec{a} \times \vec{b} = |a| \cdot |b| \cos<\vec{a}, \vec{b} >

cos曲線:

比如現在有樣本A,B,對應在輸出空間的特徵向量分別是\vec{A}\vec{B}, 先對這兩個特徵值除以各自的模。

\vec{a} = \vec{A} / |A|

\vec{b} = \vec{B} / |B|

根據求向量之間夾角公式,A,B之間的角度的cos值就是:

cos<A,B> = \vec{a} \cdot \vec{b}

這個值越大,說明,向量夾角越小,說明越相似。

 

Pytorch代碼

from torch.nn import functional as F
def calculate_cos_distance(a,b):
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    cose = torch.mm(a,b)
    return 1 - cose

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章