torch.func.linearize
- torch.func.linearize(func, *primals)
-
返回
func
在primals
处的值及其一阶近似。- 参数
-
-
func (Callable) – 一个可以接受一个或多个参数的 Python 函数。
-
primals (Tensors) – 传递给
func
的位置参数必须都是张量。这些是进行线性近似的位置值。
-
- 返回值
-
返回一个包含
(output, jvp_fn)
的元组。其中,func
应用于primals
的结果和一个函数,该函数计算在primals
处评估的func
的 JVP。 - 返回类型
如果需要在
primals
处多次计算jvp,linearize非常有用。然而,为了实现这一点,linearize会保存中间计算结果,并且其内存需求高于直接应用jvp。因此,如果所有tangents
都是已知的,使用vmap(jvp)可能比使用linearize更高效。注意
linearize 会评估
func
两次。请为单次评估的实现提交一个 issue。- 示例:
-
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>