torch.quantile
- torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) → Tensor
-
沿维度
dim计算input张量中每一行的第q分位数。为了计算分位数,我们将 q 映射到 [0, 1] 范围内的索引 [0, n],以确定排序输入中分位数的位置。如果分位数位于两个数据点
a和b(它们在排序顺序中的索引分别为i和j)之间,则根据给定的interpolation方法进行如下计算:-
linear:a + (b - a) * fraction,其中fraction是计算出的分位数索引的小数部分。 -
lower:a -
higher:b. -
nearest: 在a和b中,选择计算出的分位数索引(对于0.5的小数部分进行下舍入)更接近的那个。 -
midpoint:(a + b) / 2.
如果
q是一个一维张量,那么输出的第一个维度将代表分位数,并且其大小与q的大小相同。其余的维度则由缩减操作决定。注意
默认情况下,
dim为None,这意味着在计算前会将input张量展平。- 参数
- 关键字参数
示例:
>>> a = torch.randn(2, 3) >>> a tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]]) >>> q = torch.tensor([0.25, 0.5, 0.75]) >>> torch.quantile(a, q, dim=1, keepdim=True) tensor([[[-0.5661], [ 0.5795]], [[ 0.0795], [ 0.6706]], [[ 0.5280], [ 0.9206]]]) >>> torch.quantile(a, q, dim=1, keepdim=True).shape torch.Size([3, 2, 1]) >>> a = torch.arange(4.) >>> a tensor([0., 1., 2., 3.]) >>> torch.quantile(a, 0.6, interpolation='linear') tensor(1.8000) >>> torch.quantile(a, 0.6, interpolation='lower') tensor(1.) >>> torch.quantile(a, 0.6, interpolation='higher') tensor(2.) >>> torch.quantile(a, 0.6, interpolation='midpoint') tensor(1.5000) >>> torch.quantile(a, 0.6, interpolation='nearest') tensor(2.) >>> torch.quantile(a, 0.4, interpolation='nearest') tensor(1.)
-