pytorch中常用的函数

本文整理了笔者在学习pytorch中经常遇到的一些函数,本篇博客会不断进行更新,并且会加上自己使用背景和使用经验。

1. torch.max()函数

笔者最近在学习目标检测的相关知识,无论是在计算多个bounding box之间的IOU还是确定bounding box的类别信息的时候,都会用到torch.max()函数。torch.max()可以得到一个tensor某个维度的最大值,可以的得到两个tensor之间对应元素之间的最大值。

这个函数的签名是这样的:

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

1.1 单个tensor

情况1:如果输入tensor是一维

    # 输入一维的tensor
    a = torch.tensor([9,4,6,3,1,2,5,9])
    print(torch.max(a))

输出结果就是这一维数据中,最大的元素

    tensor(9)

情况2:如果输入tensor是二维

这时候就涉及到要比较哪个维度上的数据,是按行比较呢?还是按列比较呢?还是就是求这个tensor中最大的哪个元素?

(1)按列比较

    # 求每一列最大的元素
    b = torch.tensor([[4,8,2],[2,9,2]])
    max_value, max_index = torch.max(b,dim=0)
    print(max_value)
    print(max_index)

max_value表示得到的最终数据,max_index表示得到最大元素的每一列中索引。最终的数据结果是这样的:

    # max_value: 表示每一列的最大值
    tensor([4, 9, 2])

    # max_index: 表示最大值在每一列中的索引
    tensor([0, 1, 1])

我们发现,最终的结果相对于输入来说,结果的维度比输入的少了一个维度,这是因为我们指定了要比较哪个维度上面的数据。如果通过一幅图片显示torch.max()函数的效果就是这样:

(2)按行比较

    # 求每一列最大的元素
    b = torch.tensor([[4,8,2],[2,9,2]])
    max_value, max_index = torch.max(b,dim=1)
    print(max_value)
    print(max_index)

最终的结果就是这样的

    # max_value: 表示每一行的最大值
    tensor([8, 9])

    # max_index: 表示最大值在每一行中的索引
    tensor([1, 1])

(3)求整个tensor中最大的元素

    # 求一个tensor中最大的元素
    b = torch.tensor([[4,8,2],[2,9,2]])
    
    # 输出结果
    print(torch.max(b))
    
    # tensor([9])

 

1.2 两个tensor

torch.max()还可以比较两个tensor之间的最大值

情况1:单元素和另外一个tensor进行比较

    c = torch.tensor([5])
    d = torch.tensor([1,2,3])
    print(torch.max(c,d))

输出的结果是:

    tensor([5, 5, 5])

 

情况2:两个不同tensor之间进行比较

这种情况下,必须要求两个tensor的尺寸是一样的,不然会报错

    c = torch.tensor([[1,2,3],[2,3,4]])
    d = torch.tensor([[4,3,1],[3,6,3]])
    print(torch.max(c,d))

输出结果是这样:

    # 输出结果
    tensor([[4, 3, 3],
            [3, 6, 4]])

 

2. torch.argmax()函数

torch.argmax()函数和torch.max()函数功能差不多,只不多前者只会返回一个tensor最大元素的索引,并不会返回这个最大元素是什么。torch.max()函数不仅会返回最大值的索引,而且还会返回这个最大的元素是什么。

max()函数的函数签名是这样的:

    torch.argmax(input, dim=None, keepdim=False) -> Tensor

 其中dim的含义是这样的:the dimension to reduce(参考了这篇博客)

情况1:返回一个tensor中最大元素的索引

    # 返回一个tensor最大值的索引
    a = torch.tensor([[1,2,3],[2,3,4]])
    print(torch.argmax(b))

    # 输出值
    # tensor(5)

情况2:返回指定维度的最大元素索引

    # 将第0维的数据进行比较
    b = torch.tensor([[1,2,3],[2,3,4]])
    print(torch.argmax(b,dim=0)

    # 输入结果
    # tensor([1, 1, 1])


    # 将第1维数据进行比较
    print(torch.argmax(b,dim=1))

    # 输出结果
    # tensor([2,2])

 

3. torch.nonzero()函数

torch.nonzero()函数有很多功能,比如能够配合mask筛选出你想要的数据。比如在目标检测求confidence小于threshold的bbox,或者只选择某一类别的bbox,都可以使用nonzero()函数进行筛选。

torch.nonzero()函数签名如下:

    torch.nonzero(input, out=None) -> Tensor

这个函数的主要就是找到tensor中所有不为零元素的索引。函数的返回值是一个z*n的tensor。z表示不为零的元素的个数,n表示输入tensor的维度。

    # 打印a中所有不为0的元素的索引
    a = torch.tensor([[1,2,3],[1,2,0]])
    index = torch.nonzero(a)
    print(index)

输出结果是这样的:

    tensor([[0, 0],
            [0, 1],
            [0, 2],
            [1, 0],
            [1, 1]])

但是,我目前主要用的就是确定某一列中不含0的索引,然后再将其对应的某些行全部过滤出来。

    # 第2列中有一个非零数据,我们想把非零列所对应的行过滤出来
    a = torch.tensor([[1,2,3],[1,2,0]])
    index = torch.nonzero(a[:,2])
    print(a[index.squeeze()])

    # 输出效果
    tensor([1, 2, 3])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章