torch.Tensor.masked_scatter_
- Tensor.masked_scatter_(mask, source)
-
从
source
中复制元素到self
张量,在mask
为True的位置。具体来说,从source
开始按顺序一个接一个地将元素复制到self
中,直到遇到mask
中的每个True值为止。mask
的形状必须与底层张量的形状可广播。此外,source
应至少包含与mask
中1的数量相等的元素。- 参数
-
-
mask (BoolTensor) – 布尔类型的掩码
-
source (Tensor) – 需要复制的源张量
-
注意
mask
操作的对象是self
张量,而不是给定的source
张量。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])