torch.func.jvp
- torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)
-
代表雅可比向量积,返回一个元组,包含func(*primals)的输出以及在
primals
处评估的func
的“雅可比矩阵”与tangents
的乘积。这被称为前向模式自动微分。- 参数
-
-
func(函数)– 一个接受一个或多个参数的Python函数,其中至少有一个参数是张量,并返回一个或多个张量。
-
primals (Tensors) – 传递给
func
的位置参数必须都是张量。返回的函数也会计算这些参数的导数。 -
tangents (Tensors)– 用于计算雅可比矩阵-向量积的“向量”。它必须与传递给
func
的输入具有相同的结构和大小。 -
has_aux (bool) – 标志,表示
func
返回一个包含两个元素的元组:(output, aux)
。其中第一个元素是需要求导的函数输出,第二个元素是其他辅助对象且不会被求导。默认值:False。
-
- 返回值
-
返回一个包含
func
在primals
处评估的输出和雅可比矩阵-向量积的元组(output, jvp_out)
。如果has_aux is True
,则返回一个包含输出、雅可比矩阵-向量积以及辅助数据的元组(output, jvp_out, aux)
。
注意
你可能会遇到这个API错误:“操作符X的前向模式自动微分未实现”。如果遇到这种情况,请提交一个错误报告,我们会优先处理。
jvp 在计算从 R^1 到 R^N 的函数的梯度时非常有用。
>>> from torch.func import jvp >>> x = torch.randn([]) >>> f = lambda x: x * torch.tensor([1., 2., 3]) >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) >>> assert torch.allclose(value, f(x)) >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
jvp()
可以为具有多个输入的函数提供支持,方法是为每个输入传递切线。>>> from torch.func import jvp >>> x = torch.randn(5) >>> y = torch.randn(5) >>> f = lambda x, y: (x * y) >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) >>> assert torch.allclose(output, x + y)