torch.Tensor.masked_scatter_

Tensor.masked_scatter_(mask, source)

source中复制元素到self张量,在mask为True的位置。具体来说,从source开始按顺序一个接一个地将元素复制到self中,直到遇到mask中的每个True值为止。mask的形状必须与底层张量的形状可广播。此外,source应至少包含与mask中1的数量相等的元素。

参数
  • mask (BoolTensor) – 布尔类型的掩码

  • source (Tensor) – 需要复制的源张量

注意

mask 操作的对象是 self 张量,而不是给定的 source 张量。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])
本页目录