使用 autograd.Function 扩展 torch.func
您希望使用 torch.autograd.Function
与诸如 torch.vmap()
、torch.func.grad()
等 torch.func
变换一起。
主要有两个使用场景:
-
您希望调用不包含 PyTorch 操作的代码,并使其与这些函数变换兼容。也就是说,
torch.autograd.Function
的前向、后向等方法会调用其他系统(如 C++、CUDA、NumPy)中的函数。 -
您希望指定自定义梯度规则,例如 JAX 的
custom_vjp
和custom_jvp
。
PyTorch 将这两个概念整合到了 torch.autograd.Function
中。
基本用法
本指南假设您已经熟悉 扩展 torch.autograd,该部分详细介绍了如何使用 torch.autograd.Function
。
[`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 可以有一个接受 `ctx` 对象的 [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward),也可以有一个不接受 `ctx` 的单独的 [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward) 和一个修改 `ctx` 对象的 `setup_context()` 静态方法。
只有后者支持函数变换:
* [`forward()`](../generated/torch.autograd.Function.forward.html#torch.autograd.Function.forward) 是执行操作的代码,它不应接受 `ctx` 对象。
* `setup_context(ctx, inputs, output)` 是调用 `ctx` 方法的地方。在这里,你应该保存用于反向传播的张量(通过调用 `ctx.save_for_backward(*tensors)`),或保存非张量对象(通过将其分配给 `ctx` 对象)。
由于 `setup_context()` 只接受 `inputs` 和 `output`,唯一可以保存的是输入或输出中的对象(例如张量),或者从这些对象中派生的数量(例如 `Tensor.shape`)。如果你希望在 `Function.forward()` 中保存一个非输入的中间激活值以供反向传播使用,需要将其作为输出返回,以便传递给 `setup_context()`。
取决于变换,
-
要支持反向模式自动微分(
torch.func.grad()
,torch.func.vjp()
),torch.autograd.Function
需要一个backward()
静态方法。 -
要支持
torch.vmap()
,torch.autograd.Function
需要一个vmap()
静态方法。 -
要支持
torch.func.jvp()
,torch.autograd.Function
需要一个jvp()
静态方法。 -
要支持变换的组合(如
torch.func.jacrev()
,torch.func.jacfwd()
,torch.func.hessian()
)——你可能需要上述多个方法。
为了使 torch.autograd.Function
能够与函数转换任意组合,我们建议除了 forward()
和 setup_context()
之外的所有其他静态方法都必须是可转换的:也就是说,它们必须仅包含 PyTorch 操作符或调用其他 torch.autograd.Function
(这些函数可能会调用 C++/CUDA 等)。
让我们来看一些常见用例。
示例 1: autograd.Function 调用另一个系统
一个常见的场景是,torch.autograd.Function
的 forward()
和 backward()
方法都调用了另一个系统(如 C++、CUDA、numpy、triton)。
现在,为了简化 NumpySort
的使用(隐藏返回的中间结果,并支持默认参数和关键字参数),我们创建了一个新的函数来调用它:
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这是一个简单的检查:
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function
定义自定义梯度规则
另一种常见的情况是实现了一个使用 PyTorch 操作的 torch.autograd.Function
。PyTorch 能够自动计算 PyTorch 操作的梯度,但也许我们希望自定义梯度的计算方式。我们可能希望自定义反向传播与 PyTorch 提供的不同的一些原因包括:
- 改善数值稳定性
- 优化反向传播的性能
- 改进对边缘情况的处理(例如,NaN、Inf)
- 调整梯度(例如,梯度裁剪)
这里是一个示例,展示了如何为函数 y = x ** 3
创建一个 torch.autograd.Function
,我们在其中优化了性能表现(通常在反向传播中计算的 dx
现在在前向传播中完成)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# 在常规的 PyTorch 中,如果我们刚刚运行了 y = x ** 3,那么反向传播
# 会计算 dx = 3 * x ** 2。在这个 autograd.Function 中,我们
# 在前向传播过程中完成了这个计算。
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# 为了使 autograd.Function 能够处理高阶导数,我们需要加上 `dx` 的梯度贡献。
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更方便地使用 NumpySort
(并隐藏我们作为输出返回的中间结果),我们创建了一个新函数来调用它:
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是一个计算二阶梯度的简单检查:
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项
警告
请仔细阅读这些关于使用 torch.func
变换时的 torch.autograd.Function
的限制。我们无法捕获许多这些情况并优雅地处理错误,因此它们会导致未定义的行为。
请不要在 torch.autograd.Function
的方法中使用正在被变换的张量、requires_grad=True
的张量或双精度张量。确保在 torch.autograd.Function
的任何方法中使用的张量必须直接作为输入传递(或通过 ctx 对象),而不是从外部传入。
[`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 不处理 pytrees(可能包含或不包含 Tensors 的任意嵌套 Python 数据结构)。为了使这些 Tensors 被 autograd 跟踪,它们必须直接作为参数传递给 [`torch.autograd.Function`](../autograd.html#torch.autograd.Function)。这与 jax.{custom\_vjp, custom\_jvp} 不同,后者确实接受 pytrees。
请仅使用 [`save_for_backward()`](../generated/torch.autograd.function.FunctionCtx.save_for_backward.html#torch.autograd.function.FunctionCtx.save_for_backward) 或 `save_for_forward()` 来保存 Tensors。请不要直接在 ctx 对象上赋值 Tensors 或 Tensors 的集合——这些 Tensors 将不会被跟踪。
## [`torch.vmap()`](../generated/torch.vmap.html#torch.vmap) 支持
要使用 [`torch.autograd.Function`](../autograd.html#torch.autograd.Function) 与 [`torch.vmap()`](../generated/torch.vmap.html#torch.vmap),您必须:
* 提供一个 `vmap()` 静态方法,说明 `torch.autograd.Function` 在 `torch.vmap()` 下的行为。
* 通过设置 `generate_vmap_rule=True` 请求自动生成该规则。
自动生成 vmap 规则
如果您 torch.autograd.Function
满足以下附加约束条件,那么我们可以为其生成一个 vmap 规则。如果不满足这些约束条件或需要在 vmap 下实现自定义行为,请手动定义一个 vmap 静态方法(参见下一节)。
警告
我们无法轻松检查这些约束条件并优雅地报错。违反这些约束条件可能导致未定义的行为。
torch.autograd.Function
的forward()
、backward()
(如果存在)和jvp()
(如果存在)静态方法必须可以通过torch.vmap()
进行转换。也就是说,它们必须仅包含 PyTorch 操作(而不是 NumPy 或自定义 CUDA 内核等)。
示例:
class MyCube(torch.autograd.Function):
# 将 generate_vmap_rule 设置为 True,要求 PyTorch 自动生成 vmap 规则。
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法
如果你的 torch.autograd.Function
调用了另一个系统(如 NumPy、C++、CUDA、triton),那么为了让它与 torch.vmap()
或使用它的转换一起工作,你需要手动定义一个 vmap()
静态方法。
根据您想要使用的变换和具体应用场景,您可能不需要为所有 torch.autograd.Function
添加 vmap()
静态方法:
- 例如,
torch.func.jacrev()
在反向传播过程中应用vmap()
。因此,如果您只对使用torch.func.jacrev()
感兴趣,则只需确保backward()
静态方法能够使用 vmap。
尽管如此,我们还是建议确保所有 torch.autograd.Function
都支持 torch.vmap()
功能,特别是如果您正在编写第三方库,并希望您的 torch.autograd.Function
能与所有 torch.func()
转换组合兼容。
从概念上来说,vmap
静态方法负责定义如何在 torch.vmap()
下执行 forward()
。也就是说,它定义了如何将 forward()
转换为在带有额外维度的输入上运行,其中额外维度是指被 vmap
映射的维度。这与在 PyTorch 操作中实现 torch.vmap()
的方式类似:对于每个操作,我们定义一个 vmap
规则(有时也称为“批量处理规则”)。
以下是定义 vmap()
静态方法的方式:
-
签名是
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
表示传递给forward()
方法的所有参数。 -
vmap
静态方法定义了在使用torch.vmap()
时,forward()
方法的行为。即,给定带有额外维度(由in_dims
指定)的输入,如何计算forward()
的批量版本? -
对于
args
中的每个参数,in_dims
中都有一个对应的Optional[int]
。如果参数不是 Tensor 或不参与 vmap 运算,则该值为None
;否则,它是一个整数,表示 Tensor 的哪个维度正在被 vmap 运算。 -
info
包含一些额外的元数据,这些元数据可能有助于操作。info.batch_size
表示正在被 vmap 运算的维度的大小,而info.randomness
是传递给torch.vmap()
的randomness
选项。 -
vmap
静态方法返回一个包含(output, out_dims)
的元组。类似于in_dims
,out_dims
应该与output
具有相同的结构,并且每个输出都包含一个out_dim
,用于指定该输出是否有 vmap 维度及其索引。
示例:
注意
vmap
静态方法应确保保留整个 Function
的语义。也就是说,(伪代码)grad(vmap(MyFunc))
应该能够被 grad(map(MyFunc))
替换。
如果你的 autograd.Function
在反向传播中有任何自定义行为,请记住这一点。
注意
对于 PyTorch 可以通过 generate_vmap_rule=True
生成 vmap
规则的 Function
,编写自定义 vmap
静态方法是合法的。如果你希望生成的 vmap
规则不符合你的需求,你可能需要这样做。
torch.func.jvp()
支持
为了支持前向模式自动微分,torch.autograd.Function
必须实现 jvp()
静态方法。详情请参见 前向模式自动微分。