高斯负对数似然损失函数

classtorch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[源代码]

高斯负对数似然损失函数。

目标被视作来自高斯分布的样本,其期望值和方差由神经网络预测。对于一个用target张量建模为具有期望值张量input和正方差张量var的高斯分布,损失函数定义为:

$\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}$

其中,eps用于保证稳定性。默认情况下,损失函数中的常数项会被省略,除非full被设置为True。如果var的大小与input不同(基于同方差假设),那么var必须具有最终维度为1或少一个维度,同时其他所有尺寸保持一致,以确保正确的广播操作。

参数
  • full (bool, 可选) – 是否在损失计算中包含常数项。默认值: False

  • eps (float, 可选) – 用于稳定地将var 钳位的值(参见下方注释)。默认值:1e-6。

  • reduction (str, 可选) – 指定要应用于输出的缩减方式:'none' | 'mean' | 'sum''none': 不进行任何缩减,'mean': 输出是所有批次成员损失值的平均数,'sum': 输出是所有批次成员损失值的总和。默认值:'mean'

形状:
  • 输入: $(N, *)$$(*)$,其中 $*$ 表示任意数量的附加维度。

  • 目标: $(N, *)$$(*)$, 与输入的形状相同,或者与输入形状相同但有一个维度为1(以便进行广播操作)

  • Var: $(N, *)$$(*)$,与输入形状相同,或者有一个维度为1的形状与输入相同,或者比输入少一个维度(以允许广播操作)

  • 输出:如果 reduction'mean'(默认值)或 'sum',则为标量。 如果 reduction'none',则输出形状与输入相同,为 $(N, *)$

示例:
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

注意

在自动微分中,var 的钳位操作会被忽略,因此它不会影响梯度。

参考

Nix, D. A. 和 Weigend, A. S., “估计目标概率分布的均值和方差”,收录于1994年IEEE国际神经网络会议(ICNN’94)论文集,美国佛罗里达州奥兰多,1994年,第1卷,页55-60,doi: 10.1109/ICNN.1994.374138。

本页目录