torch.nanmean
- torch.nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) → Tensor
-
计算指定维度上所有非NaN元素的平均值。输入数据必须为浮点数或复数。
当
input
张量中没有NaN值时,此函数与torch.mean()
完全相同。在存在NaN的情况下,torch.mean()
会将NaN传播到输出结果,而torch.nanmean()
则会忽略NaN值(torch.nanmean(a)等同于torch.mean(a[~a.isnan()]))。如果
keepdim
是True
,则输出张量与输入张量大小相同,除了在dim
指定的维度上其大小为 1。否则,dim
维度会被挤压(参见torch.squeeze()
),导致输出张量比输入少 1 (或len(dim)
)个维度。- 参数
- 关键字参数
-
-
dtype (
torch.dtype
, 可选) – 返回的张量的数据类型。如果指定了dtype
,则在执行操作之前将输入张量转换为此数据类型,以防止数据溢出。默认值:None。 -
out (Tensor, 可选) – 指定输出张量。
-
参见
torch.mean()
计算平均值并传播 NaN。示例:
>>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) >>> x.mean() tensor(nan) >>> x.nanmean() tensor(1.8000) >>> x.mean(dim=0) tensor([ nan, 1.5000, 2.5000]) >>> x.nanmean(dim=0) tensor([1.0000, 1.5000, 2.5000]) # If all elements in the reduced dimensions are NaN then the result is NaN >>> torch.tensor([torch.nan]).nanmean() tensor(nan)