torch.kthvalue

torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)

返回一个命名元组 (values, indices),其中 valuesinput 张量在给定维度 dim 上每一行的第 k 小元素。而 indices 则是每个元素的位置索引。

如果没有提供dim,则默认选择输入的最后一个维度。

如果 keepdimTrue,那么 valuesindices 张量与 input 张量大小相同,只是在维度 dim 上它们的大小为 1。否则,dim 将被挤压(参见torch.squeeze()),导致 valuesindices 张量比 input 张量少一个维度。

注意

input是一个CUDA张量,并且存在多个有效的k值时,此函数可能会非确定性地返回任意一个的有效indices

参数
  • input (Tensor) – 需要输入的张量。

  • k (int) - 表示第 k 个最小元素的 k 值

  • dim (int, 可选) – 查找第 k 个值的维度

  • keepdim (bool) – 是否在输出张量中保留dim维度。

关键字参数

out (元组, 可选) – 输出的 (Tensor, LongTensor) 元组可以选择性地提供,用作输出缓冲区

示例:

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.kthvalue(x, 4)
torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3))

>>> x=torch.arange(1.,7.).resize_(2,3)
>>> x
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.]])
>>> torch.kthvalue(x, 2, 0, True)
torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]]))
本页目录