torch.Tensor.index_copy_

Tensor.index_copy_(dim, index, tensor) Tensor

根据index中给定的顺序选择索引,将tensor中的元素复制到self张量中。例如,如果dim == 0index[i] == j,那么tensor的第i行将被复制到self张量的第j行。

dim表示的维度中,其大小必须与index的长度(index 必须是一个向量)相同;同时所有其他维度必须与self匹配,否则将引发错误。

注意

如果 index 包含重复项,则 tensor 中的多个元素将被复制到 self 的同一个索引位置。由于结果取决于最后一个副本何时发生,因此是非确定性的。

参数
  • dim (int) - 需要进行索引的维度

  • index (LongTensor) - 用于从 tensor 中选取元素的索引

  • tensor (Tensor) – 包含需要复制的值的张量

示例:

>>> x = torch.zeros(5, 3)
>>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
>>> index = torch.tensor([0, 4, 2])
>>> x.index_copy_(0, index, t)
tensor([[ 1.,  2.,  3.],
        [ 0.,  0.,  0.],
        [ 7.,  8.,  9.],
        [ 0.,  0.,  0.],
        [ 4.,  5.,  6.]])
本页目录