使用 autograd.Function 扩展 torch.func

您希望使用 torch.autograd.Function 与诸如 torch.vmap()torch.func.grad()torch.func 变换一起。

主要有两个使用场景:

  • 您希望调用不包含 PyTorch 操作的代码,并使其与这些函数变换兼容。也就是说,torch.autograd.Function 的前向、后向等方法会调用其他系统(如 C++、CUDA、NumPy)中的函数。

  • 您希望指定自定义梯度规则,例如 JAX 的 custom_vjpcustom_jvp

PyTorch 将这两个概念整合到了 torch.autograd.Function 中。

基本用法

本指南假设您已经熟悉 扩展 torch.autograd,该部分详细介绍了如何使用 torch.autograd.Function

[`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 可以有一个接受 `ctx` 对象的 [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward),也可以有一个不接受 `ctx` 的单独的 [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward) 和一个修改 `ctx` 对象的 `setup_context()` 静态方法。

只有后者支持函数变换:

* [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward) 是执行操作的代码,它不应接受 `ctx` 对象。

* `setup_context(ctx, inputs, output)` 是调用 `ctx` 方法的地方。在这里,你应该保存用于反向传播的张量(通过调用 `ctx.save_for_backward(*tensors)`),或保存非张量对象(通过将其分配给 `ctx` 对象)。

由于 `setup_context()` 只接受 `inputs` 和 `output`,唯一可以保存的是输入或输出中的对象(例如张量),或者从这些对象中派生的数量(例如 `Tensor.shape`)。如果你希望在 `Function.forward()` 中保存一个非输入的中间激活值以供反向传播使用,需要将其作为输出返回,以便传递给 `setup_context()`。

取决于变换,

为了使 torch.autograd.Function 能够与函数转换任意组合,我们建议除了 forward()setup_context() 之外的所有其他静态方法都必须是可转换的:也就是说,它们必须仅包含 PyTorch 操作符或调用其他 torch.autograd.Function(这些函数可能会调用 C++/CUDA 等)。

让我们来看一些常见用例。

示例 1: autograd.Function 调用另一个系统

一个常见的场景是,torch.autograd.Functionforward()backward() 方法都调用了另一个系统(如 C++、CUDA、numpy、triton)。

现在,为了简化 NumpySort 的使用(隐藏返回的中间结果,并支持默认参数和关键字参数),我们创建了一个新的函数来调用它:

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

这是一个简单的检查:

x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))

示例 2:autograd.Function 定义自定义梯度规则

另一种常见的情况是实现了一个使用 PyTorch 操作的 torch.autograd.Function。PyTorch 能够自动计算 PyTorch 操作的梯度,但也许我们希望自定义梯度的计算方式。我们可能希望自定义反向传播与 PyTorch 提供的不同的一些原因包括:

  • 改善数值稳定性
  • 优化反向传播的性能
  • 改进对边缘情况的处理(例如,NaN、Inf)
  • 调整梯度(例如,梯度裁剪)

这里是一个示例,展示了如何为函数 y = x ** 3 创建一个 torch.autograd.Function,我们在其中优化了性能表现(通常在反向传播中计算的 dx 现在在前向传播中完成)。

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # 在常规的 PyTorch 中,如果我们刚刚运行了 y = x ** 3,那么反向传播
        # 会计算 dx = 3 * x ** 2。在这个 autograd.Function 中,我们
        # 在前向传播过程中完成了这个计算。
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # 为了使 autograd.Function 能够处理高阶导数,我们需要加上 `dx` 的梯度贡献。
        result = grad_output * dx + grad_dx * 6 * x
        return result

现在,为了更方便地使用 NumpySort(并隐藏我们作为输出返回的中间结果),我们创建了一个新函数来调用它:

def my_cube(x):
    result, _ = MyCube.apply(x)
    return result

这是一个计算二阶梯度的简单检查:

x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)

限制和注意事项

警告

请仔细阅读这些关于使用 torch.func 变换时的 torch.autograd.Function 的限制。我们无法捕获许多这些情况并优雅地处理错误,因此它们会导致未定义的行为。

请不要在 torch.autograd.Function 的方法中使用正在被变换的张量、requires_grad=True 的张量或双精度张量。确保在 torch.autograd.Function 的任何方法中使用的张量必须直接作为输入传递(或通过 ctx 对象),而不是从外部传入。

[`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 不处理 pytrees(可能包含或不包含 Tensors 的任意嵌套 Python 数据结构)。为了使这些 Tensors 被 autograd 跟踪,它们必须直接作为参数传递给 [`torch.autograd.Function`](../autograd.html#torch.autograd.Function)。这与 jax.{custom\_vjp, custom\_jvp} 不同,后者确实接受 pytrees。

请仅使用 [`save_for_backward()`](../generated/torch.autograd.function.FunctionCtx.save_for_backward.html#torch.autograd.function.FunctionCtx.save_for_backward) 或 `save_for_forward()` 来保存 Tensors。请不要直接在 ctx 对象上赋值 Tensors 或 Tensors 的集合——这些 Tensors 将不会被跟踪。

## [`torch.vmap()`](../generated/torch.vmap.html#torch.vmap) 支持

要使用 [`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 与 [`torch.vmap()`](../generated/torch.vmap.html#torch.vmap),您必须:

* 提供一个 `vmap()` 静态方法,说明 `torch.autograd.Function` 在 `torch.vmap()` 下的行为。
* 通过设置 `generate_vmap_rule=True` 请求自动生成该规则。

自动生成 vmap 规则

如果您 torch.autograd.Function 满足以下附加约束条件,那么我们可以为其生成一个 vmap 规则。如果不满足这些约束条件或需要在 vmap 下实现自定义行为,请手动定义一个 vmap 静态方法(参见下一节)。

警告

我们无法轻松检查这些约束条件并优雅地报错。违反这些约束条件可能导致未定义的行为。

  • torch.autograd.Functionforward()backward()(如果存在)和 jvp()(如果存在)静态方法必须可以通过 torch.vmap() 进行转换。也就是说,它们必须仅包含 PyTorch 操作(而不是 NumPy 或自定义 CUDA 内核等)。

示例:

class MyCube(torch.autograd.Function):
    # 将 generate_vmap_rule 设置为 True,要求 PyTorch 自动生成 vmap 规则。
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result

def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)

定义 vmap 静态方法

如果你的 torch.autograd.Function 调用了另一个系统(如 NumPy、C++、CUDA、triton),那么为了让它与 torch.vmap() 或使用它的转换一起工作,你需要手动定义一个 vmap() 静态方法。

根据您想要使用的变换和具体应用场景,您可能不需要为所有 torch.autograd.Function 添加 vmap() 静态方法:

  • 例如,torch.func.jacrev() 在反向传播过程中应用 vmap()。因此,如果您只对使用 torch.func.jacrev() 感兴趣,则只需确保 backward() 静态方法能够使用 vmap。

尽管如此,我们还是建议确保所有 torch.autograd.Function 都支持 torch.vmap() 功能,特别是如果您正在编写第三方库,并希望您的 torch.autograd.Function 能与所有 torch.func() 转换组合兼容。

从概念上来说,vmap 静态方法负责定义如何在 torch.vmap() 下执行 forward()。也就是说,它定义了如何将 forward() 转换为在带有额外维度的输入上运行,其中额外维度是指被 vmap 映射的维度。这与在 PyTorch 操作中实现 torch.vmap() 的方式类似:对于每个操作,我们定义一个 vmap 规则(有时也称为“批量处理规则”)。

以下是定义 vmap() 静态方法的方式:

  • 签名是 vmap(info, in_dims: Tuple[Optional[int]], *args),其中 *args 表示传递给 forward() 方法的所有参数。

  • vmap 静态方法定义了在使用 torch.vmap() 时,forward() 方法的行为。即,给定带有额外维度(由 in_dims 指定)的输入,如何计算 forward() 的批量版本?

  • 对于 args 中的每个参数,in_dims 中都有一个对应的 Optional[int]。如果参数不是 Tensor 或不参与 vmap 运算,则该值为 None;否则,它是一个整数,表示 Tensor 的哪个维度正在被 vmap 运算。

  • info 包含一些额外的元数据,这些元数据可能有助于操作。info.batch_size 表示正在被 vmap 运算的维度的大小,而 info.randomness 是传递给 torch.vmap()randomness 选项。

  • vmap 静态方法返回一个包含 (output, out_dims) 的元组。类似于 in_dimsout_dims 应该与 output 具有相同的结构,并且每个输出都包含一个 out_dim,用于指定该输出是否有 vmap 维度及其索引。

示例:

注意

vmap 静态方法应确保保留整个 Function 的语义。也就是说,(伪代码)grad(vmap(MyFunc)) 应该能够被 grad(map(MyFunc)) 替换。

如果你的 autograd.Function 在反向传播中有任何自定义行为,请记住这一点。

注意

对于 PyTorch 可以通过 generate_vmap_rule=True 生成 vmap 规则的 Function,编写自定义 vmap 静态方法是合法的。如果你希望生成的 vmap 规则不符合你的需求,你可能需要这样做。

torch.func.jvp() 支持

为了支持前向模式自动微分,torch.autograd.Function 必须实现 jvp() 静态方法。详情请参见 前向模式自动微分

本页目录