torch.topk

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)

沿指定维度返回给定 input 张量中的 k 个最大元素。

如果没有提供dim,则默认选择输入的最后一个维度。

如果 largestFalse,则返回 k 个最小的元素。

返回一个命名元组 (values, indices),其中包含输入张量在给定维度 dim 上每行中最大的 k 个元素的值和索引。

当布尔选项 sortedTrue 时,会确保返回的 k 元素已经排序。

参数
  • input (Tensor) – 需要输入的张量。

  • k (int) – “top-k” 中的 k

  • dim (int, 可选) – 需要进行排序的维度

  • largest (bool, 可选) – 控制是否返回最大或最小的元素

  • sorted (bool, 可选) – 是否按排序顺序返回元素

关键字参数

out (元组, 可选) – 可选的输出元组 (Tensor, LongTensor),用作输出缓冲区

示例:

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
本页目录