使用 autograd.Function 扩展 torch.func
您希望使用 torch.autograd.Function 与诸如 torch.vmap()、torch.func.grad() 等 torch.func 变换一起。
主要有两个使用场景:
-
您希望调用不包含 PyTorch 操作的代码,并使其与这些函数变换兼容。也就是说,
torch.autograd.Function的前向、后向等方法会调用其他系统(如 C++、CUDA、NumPy)中的函数。 -
您希望指定自定义梯度规则,例如 JAX 的
custom_vjp和custom_jvp。
PyTorch 将这两个概念整合到了 torch.autograd.Function 中。
基本用法
本指南假设您已经熟悉 扩展 torch.autograd,该部分详细介绍了如何使用 torch.autograd.Function。
[`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 可以有一个接受 `ctx` 对象的 [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward),也可以有一个不接受 `ctx` 的单独的 [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward) 和一个修改 `ctx` 对象的 `setup_context()` 静态方法。
只有后者支持函数变换:
* [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward) 是执行操作的代码,它不应接受 `ctx` 对象。
* `setup_context(ctx, inputs, output)` 是调用 `ctx` 方法的地方。在这里,你应该保存用于反向传播的张量(通过调用 `ctx.save_for_backward(*tensors)`),或保存非张量对象(通过将其分配给 `ctx` 对象)。
由于 `setup_context()` 只接受 `inputs` 和 `output`,唯一可以保存的是输入或输出中的对象(例如张量),或者从这些对象中派生的数量(例如 `Tensor.shape`)。如果你希望在 `Function.forward()` 中保存一个非输入的中间激活值以供反向传播使用,需要将其作为输出返回,以便传递给 `setup_context()`。
取决于变换,
-
要支持反向模式自动微分(
torch.func.grad(),torch.func.vjp()),torch.autograd.Function需要一个backward()静态方法。 -
要支持
torch.vmap(),torch.autograd.Function需要一个vmap()静态方法。 -
要支持
torch.func.jvp(),torch.autograd.Function需要一个jvp()静态方法。 -
要支持变换的组合(如
torch.func.jacrev(),torch.func.jacfwd(),torch.func.hessian())——你可能需要上述多个方法。
为了使 torch.autograd.Function 能够与函数转换任意组合,我们建议除了 forward() 和 setup_context() 之外的所有其他静态方法都必须是可转换的:也就是说,它们必须仅包含 PyTorch 操作符或调用其他 torch.autograd.Function(这些函数可能会调用 C++/CUDA 等)。
让我们来看一些常见用例。
示例 1: autograd.Function 调用另一个系统
一个常见的场景是,torch.autograd.Function 的 forward() 和 backward() 方法都调用了另一个系统(如 C++、CUDA、numpy、triton)。
现在,为了简化 NumpySort 的使用(隐藏返回的中间结果,并支持默认参数和关键字参数),我们创建了一个新的函数来调用它:
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这是一个简单的检查:
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 定义自定义梯度规则
另一种常见的情况是实现了一个使用 PyTorch 操作的 torch.autograd.Function。PyTorch 能够自动计算 PyTorch 操作的梯度,但也许我们希望自定义梯度的计算方式。我们可能希望自定义反向传播与 PyTorch 提供的不同的一些原因包括:
- 改善数值稳定性
- 优化反向传播的性能
- 改进对边缘情况的处理(例如,NaN、Inf)
- 调整梯度(例如,梯度裁剪)
这里是一个示例,展示了如何为函数 y = x ** 3 创建一个 torch.autograd.Function,我们在其中优化了性能表现(通常在反向传播中计算的 dx 现在在前向传播中完成)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# 在常规的 PyTorch 中,如果我们刚刚运行了 y = x ** 3,那么反向传播
# 会计算 dx = 3 * x ** 2。在这个 autograd.Function 中,我们
# 在前向传播过程中完成了这个计算。
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# 为了使 autograd.Function 能够处理高阶导数,我们需要加上 `dx` 的梯度贡献。
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更方便地使用 NumpySort(并隐藏我们作为输出返回的中间结果),我们创建了一个新函数来调用它:
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是一个计算二阶梯度的简单检查:
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项
警告
请仔细阅读这些关于使用 torch.func 变换时的 torch.autograd.Function 的限制。我们无法捕获许多这些情况并优雅地处理错误,因此它们会导致未定义的行为。
请不要在 torch.autograd.Function 的方法中使用正在被变换的张量、requires_grad=True 的张量或双精度张量。确保在 torch.autograd.Function 的任何方法中使用的张量必须直接作为输入传递(或通过 ctx 对象),而不是从外部传入。
[`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 不处理 pytrees(可能包含或不包含 Tensors 的任意嵌套 Python 数据结构)。为了使这些 Tensors 被 autograd 跟踪,它们必须直接作为参数传递给 [`torch.autograd.Function`](../autograd.html#torch.autograd.Function)。这与 jax.{custom\_vjp, custom\_jvp} 不同,后者确实接受 pytrees。
请仅使用 [`save_for_backward()`](../generated/torch.autograd.function.FunctionCtx.save_for_backward.html#torch.autograd.function.FunctionCtx.save_for_backward) 或 `save_for_forward()` 来保存 Tensors。请不要直接在 ctx 对象上赋值 Tensors 或 Tensors 的集合——这些 Tensors 将不会被跟踪。
## [`torch.vmap()`](../generated/torch.vmap.html#torch.vmap) 支持
要使用 [`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 与 [`torch.vmap()`](../generated/torch.vmap.html#torch.vmap),您必须:
* 提供一个 `vmap()` 静态方法,说明 `torch.autograd.Function` 在 `torch.vmap()` 下的行为。
* 通过设置 `generate_vmap_rule=True` 请求自动生成该规则。
自动生成 vmap 规则
如果您 torch.autograd.Function 满足以下附加约束条件,那么我们可以为其生成一个 vmap 规则。如果不满足这些约束条件或需要在 vmap 下实现自定义行为,请手动定义一个 vmap 静态方法(参见下一节)。
警告
我们无法轻松检查这些约束条件并优雅地报错。违反这些约束条件可能导致未定义的行为。
torch.autograd.Function的forward()、backward()(如果存在)和jvp()(如果存在)静态方法必须可以通过torch.vmap()进行转换。也就是说,它们必须仅包含 PyTorch 操作(而不是 NumPy 或自定义 CUDA 内核等)。
示例:
class MyCube(torch.autograd.Function):
# 将 generate_vmap_rule 设置为 True,要求 PyTorch 自动生成 vmap 规则。
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法
如果你的 torch.autograd.Function 调用了另一个系统(如 NumPy、C++、CUDA、triton),那么为了让它与 torch.vmap() 或使用它的转换一起工作,你需要手动定义一个 vmap() 静态方法。
根据您想要使用的变换和具体应用场景,您可能不需要为所有 torch.autograd.Function 添加 vmap() 静态方法:
- 例如,
torch.func.jacrev()在反向传播过程中应用vmap()。因此,如果您只对使用torch.func.jacrev()感兴趣,则只需确保backward()静态方法能够使用 vmap。
尽管如此,我们还是建议确保所有 torch.autograd.Function 都支持 torch.vmap() 功能,特别是如果您正在编写第三方库,并希望您的 torch.autograd.Function 能与所有 torch.func() 转换组合兼容。
从概念上来说,vmap 静态方法负责定义如何在 torch.vmap() 下执行 forward()。也就是说,它定义了如何将 forward() 转换为在带有额外维度的输入上运行,其中额外维度是指被 vmap 映射的维度。这与在 PyTorch 操作中实现 torch.vmap() 的方式类似:对于每个操作,我们定义一个 vmap 规则(有时也称为“批量处理规则”)。
以下是定义 vmap() 静态方法的方式:
-
签名是
vmap(info, in_dims: Tuple[Optional[int]], *args),其中*args表示传递给forward()方法的所有参数。 -
vmap静态方法定义了在使用torch.vmap()时,forward()方法的行为。即,给定带有额外维度(由in_dims指定)的输入,如何计算forward()的批量版本? -
对于
args中的每个参数,in_dims中都有一个对应的Optional[int]。如果参数不是 Tensor 或不参与 vmap 运算,则该值为None;否则,它是一个整数,表示 Tensor 的哪个维度正在被 vmap 运算。 -
info包含一些额外的元数据,这些元数据可能有助于操作。info.batch_size表示正在被 vmap 运算的维度的大小,而info.randomness是传递给torch.vmap()的randomness选项。 -
vmap静态方法返回一个包含(output, out_dims)的元组。类似于in_dims,out_dims应该与output具有相同的结构,并且每个输出都包含一个out_dim,用于指定该输出是否有 vmap 维度及其索引。
示例:
注意
vmap 静态方法应确保保留整个 Function 的语义。也就是说,(伪代码)grad(vmap(MyFunc)) 应该能够被 grad(map(MyFunc)) 替换。
如果你的 autograd.Function 在反向传播中有任何自定义行为,请记住这一点。
注意
对于 PyTorch 可以通过 generate_vmap_rule=True 生成 vmap 规则的 Function,编写自定义 vmap 静态方法是合法的。如果你希望生成的 vmap 规则不符合你的需求,你可能需要这样做。
torch.func.jvp() 支持
为了支持前向模式自动微分,torch.autograd.Function 必须实现 jvp() 静态方法。详情请参见 前向模式自动微分。