torch.func.linearize

torch.func.linearize(func, *primals)

返回funcprimals处的值及其一阶近似。

参数
  • func (Callable) – 一个可以接受一个或多个参数的 Python 函数。

  • primals (Tensors) – 传递给 func 的位置参数必须都是张量。这些是进行线性近似的位置值。

返回值

返回一个包含 (output, jvp_fn) 的元组。其中,func 应用于 primals 的结果和一个函数,该函数计算在 primals 处评估的 func 的 JVP。

返回类型

Tuple[Any, Callable]

如果需要在primals处多次计算jvp,linearize非常有用。然而,为了实现这一点,linearize会保存中间计算结果,并且其内存需求高于直接应用jvp。因此,如果所有tangents都是已知的,使用vmap(jvp)可能比使用linearize更高效。

注意

linearize 会评估 func 两次。请为单次评估的实现提交一个 issue。

示例:
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>
本页目录