計算tensor(矩陣)之間歐氏距離的方法

最近工作中需要用到矩陣中各個樣本之間歐氏距離,因此記錄一下,如何簡便快捷地進行tensor間歐氏距離的計算(使用Pytorch框架)。

按照我之前的想法,會進行兩輪或者一輪循環一個個地求出樣本間的歐氏距離,但是看過了michuanhaohao/reid-strong-baseline 中Euclidean_dist()方法的運算之後才發現了新大陸---------通過矩陣的方式快速的進行計算。

 

一、理論分析

       首先從理論上介紹 一下,矩陣之間歐氏距離的快速計算,參考了@frankzd 的博客,原文鏈接在

https://blog.csdn.net/frankzd/article/details/80251042

 

 

 

二、代碼分析

       接下來上代碼,我會在每一行進行必要的註釋(來源:https://github.com/michuanhaohao/reid-strong-baseline/blob/master/layers/triplet_loss.py

    def euclidean_dist(x, y):
        """
        Args:
          x: pytorch Variable, with shape [m, d]
          y: pytorch Variable, with shape [n, d]
        Returns:
          dist: pytorch Variable, with shape [m, n]
        """

        m, n = x.size(0), y.size(0)
        # xx經過pow()方法對每單個數據進行二次方操作後,在axis=1 方向(橫向,就是第一列向最後一列的方向)加和,此時xx的shape爲(m, 1),經過expand()方法,擴展n-1次,此時xx的shape爲(m, n)
        xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
        # yy會在最後進行轉置的操作
        yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
        dist = xx + yy
        # torch.addmm(beta=1, input, alpha=1, mat1, mat2, out=None),這行表示的意思是dist - 2 * x * yT 
        dist.addmm_(1, -2, x, y.t())
        # clamp()函數可以限定dist內元素的最大最小範圍,dist最後開方,得到樣本之間的距離矩陣
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        return dist

 

三、demo演示

       接下來用一個簡單的demo實現(也便於自己查驗最後結果是否正確)

import torch

def euclidean_dist(x, y):
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    dist.addmm_(1, -2, x, y.t())
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist

if __name__ == '__main__':
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [2.0, 5.0, 7.0, 9.0]])
    y = torch.tensor([[3.0, 1.0, 2.0, 5.0], [2.0, 3.0, 4.0, 6.0]])
    dist_matrix = euclidean_dist(x, y)
    print(dist_matrix)

 最後輸出的結果爲:

tensor([[2.6458, 2.6458],[7.6158, 4.6904]])

 

       理論看起來稍微有些麻煩,不過靜下心來琢磨一下,還是很簡單的。本文使用的是pytorch下的tensor變量進行的演示,對於矩陣,原理也是相同的。學會這個方法,以後就可以很高效地,而不必通過循環的方式計算矩陣間的歐氏距離了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章