torch.quantile

torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) Tensor

沿维度dim计算input张量中每一行的第q分位数。

为了计算分位数,我们将 q 映射到 [0, 1] 范围内的索引 [0, n],以确定排序输入中分位数的位置。如果分位数位于两个数据点 ab(它们在排序顺序中的索引分别为 ij)之间,则根据给定的 interpolation 方法进行如下计算:

  • linear: a + (b - a) * fraction,其中fraction是计算出的分位数索引的小数部分。

  • lower: a

  • higher: b.

  • nearest: 在 ab 中,选择计算出的分位数索引(对于0.5的小数部分进行下舍入)更接近的那个。

  • midpoint: (a + b) / 2.

如果 q 是一个一维张量,那么输出的第一个维度将代表分位数,并且其大小与 q 的大小相同。其余的维度则由缩减操作决定。

注意

默认情况下,dimNone,这意味着在计算前会将 input 张量展平。

参数
  • input (Tensor) – 需要输入的张量。

  • q (floatTensor) – 范围在 [0, 1] 内的标量或一维张量。

  • dim (int) - 需要减少的维度。

  • keepdim (bool) – 是否在输出张量中保留dim维度。

关键字参数
  • interpolation (str) – 当所需的分位数位于两个数据点之间时,指定插值方法。可选的插值方法包括 linear, lower, higher, midpointnearest。默认为 linear

  • out (Tensor, 可选) – 指定输出张量。

示例:

>>> a = torch.randn(2, 3)
>>> a
tensor([[ 0.0795, -1.2117,  0.9765],
        [ 1.1707,  0.6706,  0.4884]])
>>> q = torch.tensor([0.25, 0.5, 0.75])
>>> torch.quantile(a, q, dim=1, keepdim=True)
tensor([[[-0.5661],
        [ 0.5795]],

        [[ 0.0795],
        [ 0.6706]],

        [[ 0.5280],
        [ 0.9206]]])
>>> torch.quantile(a, q, dim=1, keepdim=True).shape
torch.Size([3, 2, 1])
>>> a = torch.arange(4.)
>>> a
tensor([0., 1., 2., 3.])
>>> torch.quantile(a, 0.6, interpolation='linear')
tensor(1.8000)
>>> torch.quantile(a, 0.6, interpolation='lower')
tensor(1.)
>>> torch.quantile(a, 0.6, interpolation='higher')
tensor(2.)
>>> torch.quantile(a, 0.6, interpolation='midpoint')
tensor(1.5000)
>>> torch.quantile(a, 0.6, interpolation='nearest')
tensor(2.)
>>> torch.quantile(a, 0.4, interpolation='nearest')
tensor(1.)
本页目录