torch.func.vjp

torch.func.vjp(func, *primals, has_aux=False)

表示向量-Jacobian 积,返回一个元组。该元组包含将 func 应用于 primals 的结果以及一个函数。当给定 cotangents 时,这个函数会计算 func 关于 primals 的反向模式 Jacobian,并将其与 cotangents 相乘。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数,必须返回一个或多个张量。

  • primals (Tensors) – 传递给 func 的位置参数必须都是张量。返回的函数也会计算这些参数的导数。

  • has_aux (bool) – 标志,表示 func 返回一个包含两个元素的元组:(output, aux)。其中第一个元素是需要求导的函数输出,第二个元素是其他辅助对象且不会被求导。默认值:False。

返回值

返回一个包含func应用于primals的结果和计算func相对于所有primals的vjp(反向模式自动微分)函数vjp_fn的元组。如果has_auxTrue,则返回一个包含辅助输出的元组(output, vjp_fn, aux)。其中,返回的vjp_fn函数将返回每个VJP的结果。

在简单情况下,vjp() 的行为与 grad() 相同。

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

然而,vjp() 可以通过为每个输出传递余切来支持具有多个输出的函数。

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 甚至可以支持输出为 Python 结构体

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 函数返回的函数将计算每个 primals 的偏导数。

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primalsf 的位置参数,而所有关键字参数都将使用它们的默认值。

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

注意

结合使用 PyTorch 的 torch.no_gradvjp。情况 1:在函数内部使用 torch.no_grad

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在这种情况下,vjp(f)(x) 会遵守内部的 torch.no_grad

情况2:在torch.no_grad上下文中使用vjp

>>> with torch.no_grad():
>>>     vjp(f)(x)

在这种情况下,vjp 尊重内部的 torch.no_grad,但不尊重外部的。这是因为 vjp 是一个“函数变换”:它的结果不应依赖于 f 之外的上下文管理器。

本页目录