torch.func API 参考

函数变换

vmap

vmap 是一个向量化的映射函数;vmap(func) 会返回一个新的函数,这个新函数会对输入的某一个维度应用 func 函数。

grad

grad 操作符有助于计算 func 相对于由 argnums 指定的输入的梯度。

grad_and_value

返回一个函数,用于计算梯度和原生正向计算的结果元组。

vjp

表示向量-Jacobian 积,返回一个元组。该元组包含将 func 应用于 primals 的结果以及一个函数。当给定 cotangents 时,这个函数会计算 func 关于 primals 的反向模式 Jacobian,并将其与 cotangents 相乘。

jvp

代表雅可比向量积,返回一个元组,包含func(*primals)的输出以及在primals处评估的func的“雅可比矩阵”与tangents的乘积。

或者更自然一些:

表示雅可比向量积,返回一个元组。该元组包含函数func(*primals)的输出以及在点primals处计算的func的“雅可比矩阵”与tangents的乘积。

linearize

返回funcprimals处的值及其一阶近似。

jacrev

使用反向模式自动微分,计算func相对于索引argnum参数的雅可比矩阵。

jacfwd

使用前向模式自动微分,计算func相对于索引argnum处参数的雅可比矩阵。

hessian

通过正向和逆向策略计算func在索引argnum处参数的Hessian矩阵。

functionalize

functionalize 是一种转换,可用于从函数中移除(中间)变异和别名,同时保持函数的语义不变。

用于操作torch.nn.Module的工具

通常,你可以对调用torch.nn.Module的函数进行变换。例如,下面是一个计算一个接受三个输入值并返回三个输出值的函数的雅可比矩阵的例子:

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

然而,如果你想计算模型参数的雅可比矩阵,就需要有一种方法来构造以这些参数为输入的函数。这就是functional_call()的作用:它接受一个nn.Module、转换后的parameters以及模块前向传递的输入。返回值是使用替换参数运行模块前向传递的结果。

这里是如何计算参数的雅可比矩阵

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

使用提供的参数和缓冲区替换模块的原有参数和缓冲区,从而执行功能调用。

stack_module_state

为集成准备一个 torch.nn.Modules 列表,使用 vmap()

replace_all_batch_norm_modules_

通过将running_meanrunning_var设置为None,并将root中任何nn.BatchNorm模块的track_running_stats参数设置为False来进行就地更新。

如果你想了解如何修复Batch Norm模块的相关信息,请参阅这里的指导。

本页目录