masked_fill_() - masked_fill() - v1.5.0


torch.Tensor
https://pytorch.org/docs/stable/tensors.html

  • torch.Tensor.masked_fill (Python method, in torch.Tensor)
  • torch.Tensor.masked_fill_ (Python method, in torch.Tensor)

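masked_fill_(mask, value) - the function name ends with an underscore. In PyTorch, an in-place version changes a tensor's values directly in its original memory, without making a copy; such methods are called in-place operators.
masked_fill(mask, value) -> Tensor - the function name has no trailing underscore. In PyTorch, an out-of-place version does not touch the original memory: it copies the tensor, modifies the copy, and returns it, leaving the original tensor unchanged.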
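The memory behaviour described above can be checked directly. The minimal sketch below compares storage pointers with `Tensor.data_ptr()` and object identity: the out-of-place call allocates a new tensor, while the in-place call writes into `data` and returns `data` itself.

```python
import torch

data = torch.randn(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
original = data.clone()  # snapshot to show data is untouched by the out-of-place call

# Out-of-place: the result lives in newly allocated memory; data keeps its values.
out = data.masked_fill(mask, 999)
print(out.data_ptr() == data.data_ptr())  # False: different storage
print(torch.equal(data, original))        # True: data unchanged

# In-place: writes into data's own storage and returns the same tensor object.
ret = data.masked_fill_(mask, 999)
print(ret is data)               # True: the returned object is data itself
print(torch.equal(data, out))    # True: data now holds the filled values
```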

1. masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
Wherever the corresponding element of mask is True (1), the element of the self tensor is replaced with value.
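Because the mask only has to be broadcastable to the tensor's shape, it need not have the same shape. The sketch below uses two illustrative masks on a `(2, 3)` tensor. A mask of shape `(3,)` selects columns in every row, and a mask of shape `(2, 1)` selects whole rows.

```python
import torch

data = torch.zeros(2, 3)

# Shape (3,) broadcasts over both rows: columns 0 and 2 are filled in each row.
col_mask = torch.tensor([True, False, True])
filled_cols = data.masked_fill(col_mask, 1.0)
# filled_cols: tensor([[1., 0., 1.],
#                      [1., 0., 1.]])

# Shape (2, 1) broadcasts over all columns: the whole second row is filled.
row_mask = torch.tensor([[False], [True]])
filled_rows = data.masked_fill(row_mask, -1.0)
# filled_rows: tensor([[ 0.,  0.,  0.],
#                      [-1., -1., -1.]])
```

For the in-place `masked_fill_`, only the mask is broadcast. The shape of `self` itself cannot change.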

1.1 Parameters

mask (BoolTensor) – the boolean mask; each element is True or False
value (float) – the value to fill in with at the positions where mask is True
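In practice the `mask` argument is often produced by a comparison, since comparison operators already return a BoolTensor. This small illustrative sketch zeroes out the negative entries of a tensor with a float `value`.

```python
import torch

x = torch.tensor([[-1.5, 0.5, 2.0],
                  [3.0, -0.25, 0.0]])

# x < 0 yields a tensor of dtype torch.bool, which is what mask expects.
neg = x < 0
print(neg.dtype)  # torch.bool

# Replace every negative element with 0.0; the others are copied unchanged.
clipped = x.masked_fill(neg, 0.0)
# clipped: tensor([[0.0000, 0.5000, 2.0000],
#                  [3.0000, 0.0000, 0.0000]])
```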

2. masked_fill(mask, value) -> Tensor

Out-of-place version of torch.Tensor.masked_fill_()
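"Out-of-place version" can be made concrete by rebuilding the same result in two other ways. One way clones the tensor and then fills the clone in place. The other uses `torch.where` to pick the fill value wherever the mask is True. This is a sketch for comparison only; it does not claim to be how PyTorch implements `masked_fill` internally.

```python
import torch

data = torch.randn(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])

# Built-in out-of-place call.
out_of_place = data.masked_fill(mask, 999)

# Equivalent: copy first, then modify the copy in place.
manual = data.clone().masked_fill_(mask, 999)

# Equivalent: choose 999 where mask is True, otherwise the original element.
via_where = torch.where(mask, torch.full_like(data, 999), data)
```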

3. Examples

3.1 masked_fill(mask, value) -> Tensor

(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> data = torch.randn(2, 3)
>>> data
tensor([[ 1.1389,  0.7854, -1.1975],
        [ 0.1931,  1.4460, -0.0749]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill(mask, 999)
>>> masked1
tensor([[ 9.9900e+02,  7.8542e-01,  9.9900e+02],
        [ 1.9310e-01,  9.9900e+02, -7.4897e-02]])
>>>
>>> data
tensor([[ 1.1389,  0.7854, -1.1975],
        [ 0.1931,  1.4460, -0.0749]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$

3.2 masked_fill_(mask, value)

(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> data = torch.randn(2, 3)
>>> data
tensor([[ 0.0718, -0.4983, -0.7344],
        [-2.0372, -1.6503,  1.6308]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill_(mask, 999)
>>> masked1
tensor([[ 9.9900e+02, -4.9832e-01,  9.9900e+02],
        [-2.0372e+00,  9.9900e+02,  1.6308e+00]])
>>>
>>> data
tensor([[ 9.9900e+02, -4.9832e-01,  9.9900e+02],
        [-2.0372e+00,  9.9900e+02,  1.6308e+00]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$

3.3 Filling with 0 and -np.inf

(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> data = torch.randn(2, 3)
>>> data
tensor([[ 0.3838, -0.8961,  0.4759],
        [ 0.4764, -0.2403,  0.4010]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill(mask, 0)
>>> masked1
tensor([[ 0.0000, -0.8961,  0.0000],
        [ 0.4764,  0.0000,  0.4010]])
>>>
>>> data
tensor([[ 0.3838, -0.8961,  0.4759],
        [ 0.4764, -0.2403,  0.4010]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$
(pt-1.4_py-3.6) yongqiang@yongqiang:~$ python
Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import numpy as np
>>>
>>> data = torch.randn(2, 3)
>>> data
tensor([[5.2904e-02, 9.4895e-01, 2.6957e-01],
        [1.2166e-03, 1.2486e+00, 3.0534e+00]])
>>>
>>> mask = torch.tensor([[True, False, True], [False, True, False]])
>>> mask
tensor([[ True, False,  True],
        [False,  True, False]])
>>>
>>> masked1 = data.masked_fill(mask, -np.inf)
>>> masked1
tensor([[      -inf, 9.4895e-01,       -inf],
        [1.2166e-03,       -inf, 3.0534e+00]])
>>>
>>> data
tensor([[5.2904e-02, 9.4895e-01, 2.6957e-01],
        [1.2166e-03, 1.2486e+00, 3.0534e+00]])
>>>
>>> exit()
(pt-1.4_py-3.6) yongqiang@yongqiang:~$
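A common reason to fill with `-np.inf` (or `float('-inf')`) is masking before a softmax, as done for padded positions in attention scores. Because `exp(-inf)` is 0, the masked positions receive exactly zero probability. The scores and mask below are made-up illustrative values.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5]])
# Hypothetical mask: True marks positions to hide (e.g. padding).
pad_mask = torch.tensor([[False, False, True],
                         [False, True, True]])

# Fill hidden positions with -inf, then normalise each row with softmax.
masked_scores = scores.masked_fill(pad_mask, float('-inf'))
probs = masked_scores.softmax(dim=-1)
# Row 0 becomes softmax([1.0, 2.0]) followed by 0; row 1 becomes [1., 0., 0.].
```

If every position in a row is masked, that row's softmax is NaN, so at least one position per row should stay unmasked.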