torch使用bool類型做檢索

一、背景

在使用torch的時候,可以通過bool類型對數組進行檢索操作。傳統的list或者dict都是使用下標和關鍵字檢索。而在torch中可以使用bool類型進行檢索,它的的目標主要是以下功能:

  • 替換torch中的某個值

二、使用

torch在bool檢索的情況下就是將爲檢索位置爲True的地方用另一個數據進行替換。

import torch

x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x), x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices]  # 將True的部分進行修改
print(labels)

# output:
"""
masked_indices第四個位置爲True,因此修改labels中第四個位置,由於噪聲數據第四個的位置是1,因此labels中的數據爲1
tensor([3, 2, 0, 1, 1])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False, False,  True])
tensor([1., 2., 3., 4., 1.])
"""


import torch

x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x)+1999, x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices]
print(labels)

# output:
"""
tensor([1516,  408,  274,  426,  126])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False,  True, False])
tensor([  1.,   2.,   3., 426.,   5.])
"""
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章