torch.topk
- torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
-
沿指定维度返回给定
input
张量中的k
个最大元素。如果没有提供
dim
,则默认选择输入的最后一个维度。如果
largest
为False
,则返回 k 个最小的元素。返回一个命名元组 (values, indices),其中包含输入张量在给定维度 dim 上每行中最大的 k 个元素的值和索引。
当布尔选项
sorted
为True
时,会确保返回的 k 元素已经排序。- 参数
- 关键字参数
-
out (元组, 可选) – 可选的输出元组 (Tensor, LongTensor),用作输出缓冲区
示例:
>>> x = torch.arange(1., 6.) >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.topk(x, 3) torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))