torch.func.vjp
- torch.func.vjp(func, *primals, has_aux=False)
-
表示向量-Jacobian 积,返回一个元组。该元组包含将
func
应用于primals
的结果以及一个函数。当给定cotangents
时,这个函数会计算func
关于primals
的反向模式 Jacobian,并将其与cotangents
相乘。- 参数
-
-
func (Callable) – 一个接受一个或多个参数的 Python 函数,必须返回一个或多个张量。
-
primals (Tensors) – 传递给
func
的位置参数必须都是张量。返回的函数也会计算这些参数的导数。 -
has_aux (bool) – 标志,表示
func
返回一个包含两个元素的元组:(output, aux)
。其中第一个元素是需要求导的函数输出,第二个元素是其他辅助对象且不会被求导。默认值:False。
-
- 返回值
-
返回一个包含
func
应用于primals
的结果和计算func
相对于所有primals
的vjp(反向模式自动微分)函数vjp_fn
的元组。如果has_aux
为True
,则返回一个包含辅助输出的元组(output, vjp_fn, aux)
。其中,返回的vjp_fn
函数将返回每个VJP的结果。
>>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> grad = vjpfunc(torch.tensor(1.))[0] >>> assert torch.allclose(grad, torch.func.grad(f)(x))
然而,
vjp()
可以通过为每个输出传递余切来支持具有多个输出的函数。>>> x = torch.randn([5]) >>> f = lambda x: (x.sin(), x.cos()) >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()
甚至可以支持输出为 Python 结构体>>> x = torch.randn([5]) >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()
函数返回的函数将计算每个primals
的偏导数。>>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) >>> cotangents = torch.randn([5, 5]) >>> vjps = vjpfunc(cotangents) >>> assert len(vjps) == 2 >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals
是f
的位置参数,而所有关键字参数都将使用它们的默认值。>>> x = torch.randn([5]) >>> def f(x, scale=4.): >>> return x * scale >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
注意
结合使用 PyTorch 的
torch.no_grad
和vjp
。情况 1:在函数内部使用torch.no_grad
:>>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在这种情况下,
vjp(f)(x)
会遵守内部的torch.no_grad
。情况2:在
torch.no_grad
上下文中使用vjp
:>>> with torch.no_grad(): >>> vjp(f)(x)
在这种情况下,
vjp
尊重内部的torch.no_grad
,但不尊重外部的。这是因为vjp
是一个“函数变换”:它的结果不应依赖于f
之外的上下文管理器。