一、背景
在使用torch的時候,可以通過bool類型對數組進行檢索操作。傳統的list或者dict都是使用下標和關鍵字檢索。而在torch中可以使用bool類型進行檢索,它的的目標主要是以下功能:
- 替換torch中的某個值
二、使用
torch在bool檢索的情況下就是將爲檢索位置爲True的地方用另一個數據進行替換。
import torch
x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x), x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices] # 將True的部分進行修改
print(labels)
# output:
"""
masked_indices第四個位置爲True,因此修改labels中第四個位置,由於噪聲數據第四個的位置是1,因此labels中的數據爲1
tensor([3, 2, 0, 1, 1])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False, False, True])
tensor([1., 2., 3., 4., 1.])
"""
import torch
x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x)+1999, x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices]
print(labels)
# output:
"""
tensor([1516, 408, 274, 426, 126])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False, True, False])
tensor([ 1., 2., 3., 426., 5.])
"""