PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

Jacobians、Hessians、hvp、vhp 等:函数变换的组合

计算雅可比矩阵或海森矩阵在许多非传统的深度学习模型中非常有用。使用 PyTorch 的常规自动微分 API(Tensor.backward()torch.autograd.grad)高效地计算这些量是困难(或繁琐)的。PyTorch 的 JAX 启发的 函数变换 API 提供了高效计算各种高阶自动微分量的方法。

本教程要求 PyTorch 2.0.0 或更高版本。

计算雅可比矩阵

importtorch
importtorch.nn.functionalasF
fromfunctoolsimport partial
_ = torch.manual_seed(0)

让我们从一个我们想要计算其雅可比矩阵的函数开始。这是一个带有非线性激活的简单线性函数。

defpredict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

让我们添加一些虚拟数据:一个权重、一个偏置和一个特征向量 x。

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector

让我们将 predict 视为一个将输入 x 从 \(R^D \to R^D\) 映射的函数。PyTorch Autograd 计算的是向量-雅可比积。为了计算这个 \(R^D \to R^D\) 函数的完整雅可比矩阵,我们需要每次使用不同的单位向量逐行计算。

defcompute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,
         0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])

与其逐行计算雅可比矩阵,我们可以使用 PyTorch 的 torch.vmap 函数变换来消除 for 循环并向量化计算。我们不能直接将 vmap 应用于 torch.autograd.grad;相反,PyTorch 提供了一个可以与 torch.vmap 组合使用的 torch.func.vjp 变换:

fromtorch.funcimport vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)

在后续的教程中,反向模式自动微分(reverse-mode AD)与vmap的组合将为我们提供逐样本梯度。而在本教程中,反向模式自动微分与vmap的组合则用于计算雅可比矩阵!vmap与自动微分变换的各种组合可以为我们提供不同的有趣结果。

PyTorch提供了torch.func.jacrev作为一个便捷函数,它执行vmap-vjp组合来计算雅可比矩阵。jacrev接受一个argnums参数,用于指定我们希望计算雅可比矩阵的输入参数。

fromtorch.funcimport jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)

让我们比较两种计算雅可比矩阵的方法的性能。函数变换版本要快得多(而且输出的数量越多,速度就越快)。

通常,我们期望通过 vmap 进行向量化可以帮助消除开销,并更好地利用硬件。

vmap 通过将外部循环下推到函数的原始操作中来实现这种魔法,从而获得更好的性能。

让我们创建一个快速函数来评估性能,并处理微秒和毫秒的测量:

defget_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
    faster = second.times[0]
    slower = first.times[0]
    gain = (slower-faster)/slower
    if gain < 0: gain *=-1
    final_gain = gain*100
    print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")

然后运行性能比较:

fromtorch.utils.benchmarkimport Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f706cb9f580>
compute_jac(xp)
  2.99 ms
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f7065c07fa0>
jacrev(predict, argnums=2)(weight, bias, x)
  720.16 us
  1 measurement, 500 runs , 1 thread

让我们使用 get_perf 函数对上述内容进行相对性能比较:

get_perf(no_vmap_timer, "without vmap",  with_vmap_timer, "vmap")
Performance delta: 75.9370 percent improvement with vmap

此外,我们可以轻松地将问题反过来,即我们希望计算模型参数(权重、偏置)的雅可比矩阵,而不是输入的雅可比矩阵。

# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)

反向模式雅可比矩阵 (jacrev) 与正向模式雅可比矩阵 (jacfwd)

我们提供了两个 API 来计算雅可比矩阵:jacrevjacfwd

  • jacrev 使用反向模式自动微分。正如您在上面看到的,它是我们的 vjpvmap 变换的组合。

  • jacfwd 使用前向模式自动微分。它是通过我们的 jvpvmap 变换的组合实现的。

jacfwdjacrev 可以互相替代,但它们的性能特征不同。

一般来说,如果您正在计算一个 \(R^N \to R^M\) 函数的雅可比矩阵,并且输出的数量远多于输入的数量(例如,\(M > N\)),那么 jacfwd 是更优的选择,否则则使用 jacrev。虽然这个规则也有例外,但非严格的论证如下:

在反向模式自动微分(AD)中,我们是逐行计算雅可比矩阵的,而在正向模式 AD(计算雅可比向量积)中,我们是逐列计算的。雅可比矩阵有 M 行和 N 列,因此如果矩阵在某一个方向上更高或更宽,我们可能会倾向于选择处理更少行或列的方法。

fromtorch.funcimport jacrev, jacfwd

首先,让我们进行输入多于输出的基准测试:

Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f70ac720df0>
jacfwd(predict, argnums=2)(weight, bias, x)
  1.30 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f7065cfe200>
jacrev(predict, argnums=2)(weight, bias, x)
  9.33 ms
  1 measurement, 500 runs , 1 thread

然后进行相对基准测试:

get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 618.5225 percent improvement with jacrev

现在情况相反 - 输出数量(M)多于输入数量(N):

Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f706530b700>
jacfwd(predict, argnums=2)(weight, bias, x)
  6.57 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f706ccdfd30>
jacrev(predict, argnums=2)(weight, bias, x)
  833.49 us
  1 measurement, 500 runs , 1 thread

以及相对性能比较:

get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 688.8138 percent improvement with jacfwd

使用 functorch.hessian 进行 Hessian 计算

我们提供了一个便捷的 API 来计算 Hessian 矩阵:torch.func.hessiani。Hessian 矩阵是 Jacobian 矩阵的 Jacobian 矩阵(或者说偏导数的偏导数,即二阶导数)。

这意味着你可以简单地组合 functorch 的 Jacobian 变换来计算 Hessian 矩阵。实际上,在底层,hessian(f) 就是 jacfwd(jacrev(f))

注意:为了提高性能:根据你的模型,你也可以使用 jacfwd(jacfwd(f))jacrev(jacrev(f)) 来计算 Hessian 矩阵,利用上述关于宽矩阵与高矩阵的经验法则。

fromtorch.funcimport hessian

# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)

让我们验证一下,无论使用 Hessian API 还是 jacfwd(jacfwd()),结果是否相同。

torch.allclose(hess_api, hess_fwdfwd)
True

批量雅可比矩阵和批量海森矩阵

在上述示例中,我们一直在处理单个特征向量。在某些情况下,您可能希望计算一批输出相对于一批输入的雅可比矩阵。也就是说,给定形状为 (B, N) 的一批输入和一个从 \(R^N \to R^M\) 的函数,我们希望得到一个形状为 (B, M, N) 的雅可比矩阵。

最简单的方法是使用 vmap

batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)

x = torch.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])

如果您有一个函数从 (B, N) -> (B, M) 转换,并且确信每个输入都会产生独立的输出,那么有时也可以在不使用 vmap 的情况下实现这一点,即先对输出求和,然后计算该函数的雅可比矩阵:

defpredict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)

如果您有一个从 \(R^N \to R^M\) 的函数,但输入是批处理的,您可以将 vmapjacrev 组合起来以计算批处理雅可比矩阵:

最后,批量海森矩阵也可以用类似的方式计算。最简单的方法是使用 vmap 对海森矩阵计算进行批处理,但在某些情况下,求和技巧同样适用。

compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])

计算 Hessian-向量积

计算 Hessian-向量积(hvp)的简单方法是显式生成完整的 Hessian 矩阵并与向量进行点积。我们可以做得更好:事实证明,我们不需要显式生成完整的 Hessian 矩阵也能完成这一计算。我们将介绍两种(众多)不同的策略来计算 Hessian-向量积:

  • 将反向模式自动微分(AD)与反向模式 AD 结合使用
  • 将反向模式 AD 与正向模式 AD 结合使用

将反向模式 AD 与正向模式 AD 结合使用(而不是反向模式与反向模式结合),通常是计算 Hessian-向量积更节省内存的方式,因为正向模式 AD 不需要构建 Autograd 计算图并保存中间结果以便反向传播:

fromtorch.funcimport jvp, grad, vjp

defhvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

这是一个示例用法。

deff(x):
  return x.sin().sum()

x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))

如果 PyTorch 的前向自动微分(forward-AD)不支持您的操作,那么我们可以改用反向模式自动微分(reverse-mode AD)与反向模式自动微分进行组合:

defhvp_revrev(f, primals, tangents):
  _, vjp_fn = vjp(grad(f), *primals)
  return vjp_fn(*tangents)

result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])
本页目录