torch.autograd.gradcheck

torch.autograd.gradcheck.gradcheck(func, inputs, *, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, nondet_tol=0.0, check_undefined_grad=True, check_grad_dtypes=False, check_batched_grad=False, check_batched_forward_grad=False, check_forward_ad=False, check_backward_ad=True, fast_mode=False, masked=None)[源代码]

将通过小的有限差分计算得到的梯度,与inputs中具有requires_grad=True且为浮点或复数类型的张量的解析梯度进行对比。

数值梯度和解析梯度的检查使用 allclose() 函数。

对于大多数我们为了优化目的而考虑的复杂函数,Jacobian 概念并不存在。相反,gradcheck 会验证 Wirtinger 和共轭 Wirtinger 导数的数值和解析值是否一致。由于梯度计算假设整个函数具有实数值输出,因此我们需要以特殊方式处理具有复数值输出的函数。对于这些函数,gradcheck 应用于两个实数值函数:一个对应于取第一个函数的复数输出的实部,另一个对应于取第二个函数的复数输出的虚部。更多详情,请参阅 复数 Autograd

注意

默认值是为双精度的 input 设计的。如果 input 的精度较低(例如 FloatTensor),此检查可能会失败。

注意

在非可微点上进行评估时,梯度检查可能会失败,因为通过有限差分计算出的数值梯度可能与解析方法得出的结果不同(不一定表示其中一方是错误的)。更多背景信息,请参阅非可微函数的梯度

警告

如果 input 中的任何受检张量存在内存重叠(即不同的索引指向相同的内存地址,例如通过 torch.expand() 引起的情况),此检查可能会失败。因为在这些索引处由点扰动计算出的数值梯度会改变所有共享相同内存地址的其他索引处的值。

参数
  • func函数)– 一个接受张量输入并返回张量或张量元组的 Python 函数

  • inputs (元组 of TensorTensor) – 函数的输入参数

  • eps (float, 可选) – 有限差分计算中的扰动值

  • atol (float, 可选) – 绝对 tolerance

    改为更自然的表达:

    atol (float, 可选) – 绝对容差值

  • rtol (float, 可选) – 相对容忍度

  • raise_exception (bool, 可选) – 表示是否在检查失败时抛出异常。如果抛出异常,会提供更多关于失败原因的详细信息,这在调试梯度检查时非常有用。

  • nondet_tol (float, 可选) – 非确定性的容忍度。在使用相同的输入进行微分时,结果必须完全一致(默认值为 0.0)或在此容忍度范围内。

  • check_undefined_grad (bool, 可选) – 如果为True,则检查未定义的输出梯度是否被支持并视为零值,适用于Tensor 输出。

  • check_batched_grad (bool, optional) – 如果为 True,则检查是否可以使用原型 vmap 支持来计算批处理梯度。默认值为 False。

  • check_batched_forward_grad (bool, optional) – 如果设置为 True,则检查是否可以使用正向自动微分和原型 vmap 支持来计算批量前向梯度。默认值为 False

  • check_forward_ad (bool, optional) – 如果设置为 True,则验证前向模式自动微分计算的梯度是否与数值梯度一致。默认值为 False

  • check_backward_ad (bool, optional) – 如果设置为 False,则不执行任何依赖于反向模式自动微分的检查。默认值为 True

  • fast_mode (bool, 可选) – 快速模式目前仅适用于 R 到 R 的函数的 gradcheck 和 gradgradcheck。如果输入和输出都不是复数,则会运行一个更快版本的 gradcheck,该版本不再计算完整的雅可比矩阵;否则,将回退到慢速实现。

  • masked (bool, 可选) – 如果设置为 True,将忽略稀疏张量中未指定元素的梯度。默认值为 False

返回值

如果所有差异都满足 allclose 条件,则返回 True

返回类型

bool

本页目录