本文整理了笔者在学习pytorch中经常遇到的一些函数,本篇博客会不断进行更新,并且会加上自己使用背景和使用经验。
1. torch.max()函数
笔者最近在学习目标检测的相关知识,无论是在计算多个bounding box之间的IOU还是确定bounding box的类别信息的时候,都会用到torch.max()函数。torch.max()可以得到一个tensor某个维度的最大值,可以的得到两个tensor之间对应元素之间的最大值。
这个函数的签名是这样的:
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
1.1 单个tensor
情况1:如果输入tensor是一维:
# 输入一维的tensor
a = torch.tensor([9,4,6,3,1,2,5,9])
print(torch.max(a))
输出结果就是这一维数据中,最大的元素
tensor(9)
情况2:如果输入tensor是二维:
这时候就涉及到要比较哪个维度上的数据,是按行比较呢?还是按列比较呢?还是就是求这个tensor中最大的哪个元素?
(1)按列比较
# 求每一列最大的元素
b = torch.tensor([[4,8,2],[2,9,2]])
max_value, max_index = torch.max(b,dim=0)
print(max_value)
print(max_index)
max_value表示得到的最终数据,max_index表示得到最大元素的每一列中索引。最终的数据结果是这样的:
# max_value: 表示每一列的最大值
tensor([4, 9, 2])
# max_index: 表示最大值在每一列中的索引
tensor([0, 1, 1])
我们发现,最终的结果相对于输入来说,结果的维度比输入的少了一个维度,这是因为我们指定了要比较哪个维度上面的数据。如果通过一幅图片显示torch.max()函数的效果就是这样:
(2)按行比较
# 求每一列最大的元素
b = torch.tensor([[4,8,2],[2,9,2]])
max_value, max_index = torch.max(b,dim=1)
print(max_value)
print(max_index)
最终的结果就是这样的
# max_value: 表示每一行的最大值
tensor([8, 9])
# max_index: 表示最大值在每一行中的索引
tensor([1, 1])
(3)求整个tensor中最大的元素
# 求一个tensor中最大的元素
b = torch.tensor([[4,8,2],[2,9,2]])
# 输出结果
print(torch.max(b))
# tensor([9])
1.2 两个tensor
torch.max()还可以比较两个tensor之间的最大值
情况1:单元素和另外一个tensor进行比较
c = torch.tensor([5])
d = torch.tensor([1,2,3])
print(torch.max(c,d))
输出的结果是:
tensor([5, 5, 5])
情况2:两个不同tensor之间进行比较
这种情况下,必须要求两个tensor的尺寸是一样的,不然会报错
c = torch.tensor([[1,2,3],[2,3,4]])
d = torch.tensor([[4,3,1],[3,6,3]])
print(torch.max(c,d))
输出结果是这样:
# 输出结果
tensor([[4, 3, 3],
[3, 6, 4]])
2. torch.argmax()函数
torch.argmax()函数和torch.max()函数功能差不多,只不多前者只会返回一个tensor最大元素的索引,并不会返回这个最大元素是什么。torch.max()函数不仅会返回最大值的索引,而且还会返回这个最大的元素是什么。
max()函数的函数签名是这样的:
torch.argmax(input, dim=None, keepdim=False) -> Tensor
其中dim的含义是这样的:the dimension to reduce(参考了这篇博客)
情况1:返回一个tensor中最大元素的索引
# 返回一个tensor最大值的索引
a = torch.tensor([[1,2,3],[2,3,4]])
print(torch.argmax(b))
# 输出值
# tensor(5)
情况2:返回指定维度的最大元素索引
# 将第0维的数据进行比较
b = torch.tensor([[1,2,3],[2,3,4]])
print(torch.argmax(b,dim=0)
# 输入结果
# tensor([1, 1, 1])
# 将第1维数据进行比较
print(torch.argmax(b,dim=1))
# 输出结果
# tensor([2,2])
3. torch.nonzero()函数
torch.nonzero()函数有很多功能,比如能够配合mask筛选出你想要的数据。比如在目标检测求confidence小于threshold的bbox,或者只选择某一类别的bbox,都可以使用nonzero()函数进行筛选。
torch.nonzero()函数签名如下:
torch.nonzero(input, out=None) -> Tensor
这个函数的主要就是找到tensor中所有不为零元素的索引。函数的返回值是一个z*n的tensor。z表示不为零的元素的个数,n表示输入tensor的维度。
# 打印a中所有不为0的元素的索引
a = torch.tensor([[1,2,3],[1,2,0]])
index = torch.nonzero(a)
print(index)
输出结果是这样的:
tensor([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[1, 1]])
但是,我目前主要用的就是确定某一列中不含0的索引,然后再将其对应的某些行全部过滤出来。
# 第2列中有一个非零数据,我们想把非零列所对应的行过滤出来
a = torch.tensor([[1,2,3],[1,2,0]])
index = torch.nonzero(a[:,2])
print(a[index.squeeze()])
# 输出效果
tensor([1, 2, 3])