torch.kthvalue
- torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)
-
返回一个命名元组
(values, indices),其中values是input张量在给定维度dim上每一行的第k小元素。而indices则是每个元素的位置索引。如果没有提供
dim,则默认选择输入的最后一个维度。如果
keepdim为True,那么values和indices张量与input张量大小相同,只是在维度dim上它们的大小为 1。否则,dim将被挤压(参见torch.squeeze()),导致values和indices张量比input张量少一个维度。注意
当
input是一个CUDA张量,并且存在多个有效的k值时,此函数可能会非确定性地返回任意一个的有效indices。- 参数
- 关键字参数
-
out (元组, 可选) – 输出的 (Tensor, LongTensor) 元组可以选择性地提供,用作输出缓冲区
示例:
>>> x = torch.arange(1., 6.) >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.kthvalue(x, 4) torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) >>> x=torch.arange(1.,7.).resize_(2,3) >>> x tensor([[ 1., 2., 3.], [ 4., 5., 6.]]) >>> torch.kthvalue(x, 2, 0, True) torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]]))