Jacobians、Hessians、hvp、vhp 等:函数变换的组合
计算雅可比矩阵或海森矩阵在许多非传统的深度学习模型中非常有用。使用 PyTorch 的常规自动微分 API(Tensor.backward()
、torch.autograd.grad
)高效地计算这些量是困难(或繁琐)的。PyTorch 的 JAX 启发的 函数变换 API 提供了高效计算各种高阶自动微分量的方法。
本教程要求 PyTorch 2.0.0 或更高版本。
计算雅可比矩阵
让我们从一个我们想要计算其雅可比矩阵的函数开始。这是一个带有非线性激活的简单线性函数。
让我们添加一些虚拟数据:一个权重、一个偏置和一个特征向量 x。
让我们将 predict
视为一个将输入 x
从 映射的函数。PyTorch Autograd 计算的是向量-雅可比积。为了计算这个 函数的完整雅可比矩阵,我们需要每次使用不同的单位向量逐行计算。
与其逐行计算雅可比矩阵,我们可以使用 PyTorch 的 torch.vmap
函数变换来消除 for 循环并向量化计算。我们不能直接将 vmap
应用于 torch.autograd.grad
;相反,PyTorch 提供了一个可以与 torch.vmap
组合使用的 torch.func.vjp
变换:
在后续的教程中,反向模式自动微分(reverse-mode AD)与vmap
的组合将为我们提供逐样本梯度。而在本教程中,反向模式自动微分与vmap
的组合则用于计算雅可比矩阵!vmap
与自动微分变换的各种组合可以为我们提供不同的有趣结果。
PyTorch提供了torch.func.jacrev
作为一个便捷函数,它执行vmap-vjp
组合来计算雅可比矩阵。jacrev
接受一个argnums
参数,用于指定我们希望计算雅可比矩阵的输入参数。
让我们比较两种计算雅可比矩阵的方法的性能。函数变换版本要快得多(而且输出的数量越多,速度就越快)。
通常,我们期望通过 vmap
进行向量化可以帮助消除开销,并更好地利用硬件。
vmap
通过将外部循环下推到函数的原始操作中来实现这种魔法,从而获得更好的性能。
让我们创建一个快速函数来评估性能,并处理微秒和毫秒的测量:
然后运行性能比较:
让我们使用 get_perf
函数对上述内容进行相对性能比较:
此外,我们可以轻松地将问题反过来,即我们希望计算模型参数(权重、偏置)的雅可比矩阵,而不是输入的雅可比矩阵。
反向模式雅可比矩阵 (jacrev
) 与正向模式雅可比矩阵 (jacfwd
)
我们提供了两个 API 来计算雅可比矩阵:jacrev
和 jacfwd
:
-
jacrev
使用反向模式自动微分。正如您在上面看到的,它是我们的vjp
和vmap
变换的组合。 -
jacfwd
使用前向模式自动微分。它是通过我们的jvp
和vmap
变换的组合实现的。
jacfwd
和 jacrev
可以互相替代,但它们的性能特征不同。
一般来说,如果您正在计算一个 函数的雅可比矩阵,并且输出的数量远多于输入的数量(例如,),那么 jacfwd
是更优的选择,否则则使用 jacrev
。虽然这个规则也有例外,但非严格的论证如下:
在反向模式自动微分(AD)中,我们是逐行计算雅可比矩阵的,而在正向模式 AD(计算雅可比向量积)中,我们是逐列计算的。雅可比矩阵有 M 行和 N 列,因此如果矩阵在某一个方向上更高或更宽,我们可能会倾向于选择处理更少行或列的方法。
首先,让我们进行输入多于输出的基准测试:
然后进行相对基准测试:
现在情况相反 - 输出数量(M)多于输入数量(N):
以及相对性能比较:
使用 functorch.hessian 进行 Hessian 计算
我们提供了一个便捷的 API 来计算 Hessian 矩阵:torch.func.hessiani
。Hessian 矩阵是 Jacobian 矩阵的 Jacobian 矩阵(或者说偏导数的偏导数,即二阶导数)。
这意味着你可以简单地组合 functorch 的 Jacobian 变换来计算 Hessian 矩阵。实际上,在底层,hessian(f)
就是 jacfwd(jacrev(f))
。
注意:为了提高性能:根据你的模型,你也可以使用 jacfwd(jacfwd(f))
或 jacrev(jacrev(f))
来计算 Hessian 矩阵,利用上述关于宽矩阵与高矩阵的经验法则。
让我们验证一下,无论使用 Hessian API 还是 jacfwd(jacfwd())
,结果是否相同。
批量雅可比矩阵和批量海森矩阵
在上述示例中,我们一直在处理单个特征向量。在某些情况下,您可能希望计算一批输出相对于一批输入的雅可比矩阵。也就是说,给定形状为 (B, N)
的一批输入和一个从 的函数,我们希望得到一个形状为 (B, M, N)
的雅可比矩阵。
最简单的方法是使用 vmap
:
如果您有一个函数从 (B, N) -> (B, M) 转换,并且确信每个输入都会产生独立的输出,那么有时也可以在不使用 vmap
的情况下实现这一点,即先对输出求和,然后计算该函数的雅可比矩阵:
如果您有一个从 的函数,但输入是批处理的,您可以将 vmap
与 jacrev
组合起来以计算批处理雅可比矩阵:
最后,批量海森矩阵也可以用类似的方式计算。最简单的方法是使用 vmap
对海森矩阵计算进行批处理,但在某些情况下,求和技巧同样适用。
计算 Hessian-向量积
计算 Hessian-向量积(hvp)的简单方法是显式生成完整的 Hessian 矩阵并与向量进行点积。我们可以做得更好:事实证明,我们不需要显式生成完整的 Hessian 矩阵也能完成这一计算。我们将介绍两种(众多)不同的策略来计算 Hessian-向量积:
- 将反向模式自动微分(AD)与反向模式 AD 结合使用
- 将反向模式 AD 与正向模式 AD 结合使用
将反向模式 AD 与正向模式 AD 结合使用(而不是反向模式与反向模式结合),通常是计算 Hessian-向量积更节省内存的方式,因为正向模式 AD 不需要构建 Autograd 计算图并保存中间结果以便反向传播:
这是一个示例用法。
如果 PyTorch 的前向自动微分(forward-AD)不支持您的操作,那么我们可以改用反向模式自动微分(reverse-mode AD)与反向模式自动微分进行组合: