Pytorch: tensor中大于某个值的所有元素, 如置0
a = torch.rand((2, 3))
# tensor([[0.2620, 0.4850, 0.5924],
# [0.4152, 0.0475, 0.5491]])
zero = torch.zeros_like(a)
# tensor([[0., 0., 0.],
# [0., 0., 0.]])
one = torch.ones_like(a)
# tensor([[1., 1., 1.],
# [1., 1., 1.]])
# a中大于0.5的用zero(0)替换,否则a替换,即不变
a = torch.where(a > 0.5, zero, a))
# tensor([[0.2620, 0.4850, 0.0000],
# [0.4152, 0.0475, 0.0000]])
# a中大于0.5的用one(1)替换,否则a替换,即不变
print(torch.where(a > 0.5, one, a))
# tensor([[0.2620, 0.4850, 1.0000],
# [0.4152, 0.0475, 1.0000]])
Numpy: 矩阵中大于某个值的所有元素, 如置0
# 矩阵a中大于Threshold(阈值)的部分置0
a[a > Threshold] = 0
# 矩阵a中小鱼Threshold(阈值)的部分置0
a[a < Threshold] = 0