torch.func.jacfwd
- torch.func.jacfwd(func, argnums=0, has_aux=False, *, randomness='error')
-
使用前向模式自动微分,计算
func
相对于索引argnum
处参数的雅可比矩阵。- 参数
- 返回值
-
返回一个函数,该函数接受与
func
相同的输入,并返回func
相对于argnums
参数的雅可比矩阵。如果has_aux
为 True,则返回的函数将返回一个包含雅可比矩阵和辅助对象的元组(jacobian, aux)
,其中jacobian
是雅可比矩阵,而aux
是由func
返回的辅助对象。
注意
你可能会遇到此 API 因为 “操作符 X 的前向模式自动微分未实现” 而产生的错误。如果是这样,请提交一个 bug 报告,我们将优先处理。另一种选择是使用
jacrev()
,它支持更多的操作符。使用逐点的一元操作的基本用法将得到一个对角矩阵作为雅可比矩阵
>>> from torch.func import jacfwd >>> x = torch.randn(5) >>> jacobian = jacfwd(torch.sin)(x) >>> expected = torch.diag(torch.cos(x)) >>> assert torch.allclose(jacobian, expected)
jacfwd()
可以与 vmap 结合使用来生成批量雅可比矩阵。>>> from torch.func import jacfwd, vmap >>> x = torch.randn(64, 5) >>> jacobian = vmap(jacfwd(torch.sin))(x) >>> assert jacobian.shape == (64, 5, 5)
如果你想同时计算函数的输出和雅可比矩阵,可以使用
has_aux
标志将输出作为辅助对象返回。>>> from torch.func import jacfwd >>> x = torch.randn(5) >>> >>> def f(x): >>> return x.sin() >>> >>> def g(x): >>> result = f(x) >>> return result, result >>> >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x) >>> assert torch.allclose(f_x, f(x))
此外,
jacrev()
可以与自身或其他jacrev()
组合,来生成 Hessian 矩阵。>>> from torch.func import jacfwd, jacrev >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hessian = jacfwd(jacrev(f))(x) >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
默认情况下,
jacfwd()
会根据第一个输入计算雅可比矩阵。但是,可以通过设置argnums
参数来针对不同的输入计算雅可比矩阵:>>> from torch.func import jacfwd >>> def f(x, y): >>> return x + y ** 2 >>> >>> x, y = torch.randn(5), torch.randn(5) >>> jacobian = jacfwd(f, argnums=1)(x, y) >>> expected = torch.diag(2 * y) >>> assert torch.allclose(jacobian, expected)
此外,将元组传递给
argnums
可以计算多个参数的雅可比矩阵。>>> from torch.func import jacfwd >>> def f(x, y): >>> return x + y ** 2 >>> >>> x, y = torch.randn(5), torch.randn(5) >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y) >>> expectedX = torch.diag(torch.ones_like(x)) >>> expectedY = torch.diag(2 * y) >>> assert torch.allclose(jacobian[0], expectedX) >>> assert torch.allclose(jacobian[1], expectedY)