torch.func API 参考
函数变换
vmap 是一个向量化的映射函数; |
|
|
|
返回一个函数,用于计算梯度和原生正向计算的结果元组。 |
|
表示向量-Jacobian 积,返回一个元组。该元组包含将 |
|
代表雅可比向量积,返回一个元组,包含func(*primals)的输出以及在 表示雅可比向量积,返回一个元组。该元组包含函数func(*primals)的输出以及在点 |
|
返回 |
|
使用反向模式自动微分,计算 |
|
使用前向模式自动微分,计算 |
|
通过正向和逆向策略计算 |
|
functionalize 是一种转换,可用于从函数中移除(中间)变异和别名,同时保持函数的语义不变。 |
用于操作torch.nn.Module的工具
通常,你可以对调用torch.nn.Module
的函数进行变换。例如,下面是一个计算一个接受三个输入值并返回三个输出值的函数的雅可比矩阵的例子:
model = torch.nn.Linear(3, 3) def f(x): return model(x) x = torch.randn(3) jacobian = jacrev(f)(x) assert jacobian.shape == (3, 3)
然而,如果你想计算模型参数的雅可比矩阵,就需要有一种方法来构造以这些参数为输入的函数。这就是functional_call()
的作用:它接受一个nn.Module、转换后的parameters
以及模块前向传递的输入。返回值是使用替换参数运行模块前向传递的结果。
这里是如何计算参数的雅可比矩阵
model = torch.nn.Linear(3, 3) def f(params, x): return torch.func.functional_call(model, params, x) x = torch.randn(3) jacobian = jacrev(f)(dict(model.named_parameters()), x)
使用提供的参数和缓冲区替换模块的原有参数和缓冲区,从而执行功能调用。 |
|
为集成准备一个 torch.nn.Modules 列表,使用 |
|
通过将 |
如果你想了解如何修复Batch Norm模块的相关信息,请参阅这里的指导。