torch.func.hessian
- torch.func.hessian(func, argnums=0)
-
通过正向和逆向策略计算
func
在索引argnum
处参数的Hessian矩阵。前向过反向策略(组合
jacfwd(jacrev(func))
)是实现良好性能的良好默认选择。也可以通过其他方式组合jacfwd()
和jacrev()
来计算黑塞矩阵,例如jacfwd(jacfwd(func))
或jacrev(jacrev(func))
。- 参数
- 返回值
-
返回一个函数,该函数接受与
func
相同的输入,并返回func
在argnums
处相对于参数的海森矩阵。
注意
你可能会遇到此 API 因为“操作符 X 的前向模式自动微分未实现”而报错的情况。如果是这样,请提交一个 bug 报告,我们会优先处理。另一种方法是使用
jacrev(jacrev(func))
,它具有更广泛的操作符支持。对于从 R^N 到 R^1 的函数,其基本应用会产生一个 N x N 的黑塞矩阵。
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))