torch.take

torch.take(input, index) Tensor

根据给定的索引返回一个新的张量,该张量包含中的相应元素。尽管输入张量实际上是多维的,但在处理时将其视为一维的。最终的结果将与索引的形状相同。

参数
  • input (Tensor) – 需要输入的张量。

  • index (LongTensor) – 张量中的索引值

示例:

>>> src = torch.tensor([[4, 3, 5],
...                     [6, 7, 8]])
>>> torch.take(src, torch.tensor([0, 2, 5]))
tensor([ 4,  5,  8])
本页目录