Pytorch 替換tensor中大於某個值的所有元素

Pytorch: tensor中大於某個值的所有元素, 如置0

a = torch.rand((2, 3))
# tensor([[0.2620, 0.4850, 0.5924],
#         [0.4152, 0.0475, 0.5491]])

zero = torch.zeros_like(a)
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])

one = torch.ones_like(a)
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])

# a中大於0.5的用zero(0)替換,否則a替換,即不變
a = torch.where(a > 0.5, zero, a))
# tensor([[0.2620, 0.4850, 0.0000],
#         [0.4152, 0.0475, 0.0000]])

# a中大於0.5的用one(1)替換,否則a替換,即不變
print(torch.where(a > 0.5, one, a))
# tensor([[0.2620, 0.4850, 1.0000],
#         [0.4152, 0.0475, 1.0000]])

Numpy: 矩陣中大於某個值的所有元素, 如置0

# 矩陣a中大於Threshold(閾值)的部分置0
a[a > Threshold] = 0
# 矩陣a中小魚Threshold(閾值)的部分置0
a[a < Threshold] = 0

 

 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章