torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) Tensor

沿由dim指定的轴收集值。

对于一个3D张量,其输出如下指定:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

inputindex 必须具有相同的维度数量,并且对于所有不等于 dim 的维度 d,需要满足 index.size(d) <= input.size(d)out 将与 index 具有相同的形状。需要注意的是,inputindex 不能相互广播。

参数
  • input (Tensor) – 来源张量

  • dim (int) - 需要进行索引的轴

  • index (LongTensor) - 收集元素的索引

关键字参数
  • sparse_grad (bool, optional) – 如果为 True,则 input 的梯度将是一个稀疏张量。

  • out (Tensor, 可选) – 输出的目标张量

示例:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
本页目录