torch.func.hessian

torch.func.hessian(func, argnums=0)

通过正向和逆向策略计算func在索引argnum处参数的Hessian矩阵。

前向过反向策略(组合 jacfwd(jacrev(func)))是实现良好性能的良好默认选择。也可以通过其他方式组合 jacfwd()jacrev() 来计算黑塞矩阵,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

参数
  • func(函数)– 一个接受一个或多个参数的Python函数,其中至少有一个参数是张量,并返回一个或多个张量。

  • argnums (intTuple[int]) – 可选,用于指定要计算海森矩阵的参数。可以是整数或整数元组。默认值:0。

返回值

返回一个函数,该函数接受与 func 相同的输入,并返回 funcargnums 处相对于参数的海森矩阵。

注意

你可能会遇到此 API 因为“操作符 X 的前向模式自动微分未实现”而报错的情况。如果是这样,请提交一个 bug 报告,我们会优先处理。另一种方法是使用 jacrev(jacrev(func)),它具有更广泛的操作符支持。

对于从 R^N 到 R^1 的函数,其基本应用会产生一个 N x N 的黑塞矩阵。

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
本页目录