常见问题
作者: Mark Saroufim
支持使用 torch.compile
进行训练吗?
torch.compile
支持训练,并使用 AOTAutograd 捕捉反向过程:
-
.forward()
图和optimizer.step()
被 TorchDynamo 的 Pythonevalframe
前端捕获。 -
对于torchdynamo捕获的每个
.forward()
段,它使用AOTAutograd生成相应的反向传播图段。 -
每一对正向和反向图(可选地)通过最小割划分来保存它们之间最少的状态信息。
-
前向和后向的配对操作被封装在
autograd.function
模块中。 -
当 Usercode 调用
.backward()
时,仍然会触发 eager 的自动微分引擎。该引擎将每个 编译后的反向 图作为单个操作来运行,并且还会执行任何未编译的 eager 操作的.backward()
函数。
支持分布式代码吗?
torch.compile
支持 DistributedDataParallel
(DDP)。其他分布式训练库的支持也在考虑之中。
分布式代码在 Dynamo 中具有挑战性的主要原因在于 AOTAutograd 同时展开了前向和后向传递,并为后端提供了两个图以供优化。这对于分布式代码来说是一个问题,因为我们希望理想地将通信操作与计算重叠起来。Eager PyTorch 通过不同的方式来实现这一点,例如 DDP/FSDP 使用 autograd 挂钩、模块挂钩以及对模块状态的修改/变异。在 Dynamo 的简单应用中,原本应在后向传递的操作之后立即运行的挂钩可能会被延迟到整个编译后的后向操作区域之后,这是由于 AOTAutograd 编译函数与调度器挂钩之间的交互方式。
使用 Dynamo 优化 DDP 的基本策略概述在 distributed.py 中,主要思想是在 DDP 桶边界处进行图拆分。
当分布式数据并行(DDP)中的每个节点需要与其他节点同步权重时,它会将梯度和参数组织到桶中。这样可以减少通信时间,并使一个节点能够将其部分梯度广播给其他等待的节点。
分布式代码中的图中断意味着你可以期望 dynamo 及其后端优化分布式程序的计算开销,但不会优化通信开销。图中断可能会干扰编译加速,如果减少的图大小限制了编译器的数据融合机会。然而,随着图规模的增加,收益递减,因为目前大多数计算优化都是局部融合。因此,在实践中这种方法可能是足够的。
是否还需要导出整个图形?
对于大多数模型,你可能不需要这样做,并且可以像这样使用torch.compile()
。但在某些情况下需要完整的图,你可以通过简单地运行torch.compile(..., fullgraph=True)
来确保生成完整图。这些情况包括:
-
大型训练任务,例如需要管道并行和其它高级分片策略的$250K以上的任务。
-
像TensorRT 或 AITemplate 这样的推理优化器,依赖于比训练优化器更激进的融合策略。
-
在移动设备上进行训练或推理。
未来的 work 将包括将通信操作记录为图形,协调这些操作与计算优化,并优化通信操作。
我的代码为何会崩溃?
如果你的代码在没有使用 torch.compile
时运行良好,但在启用它后开始崩溃,则最重要的第一步是确定失败发生在堆栈的哪个部分。为了排查这个问题,请按照以下步骤操作,并且只在前一步成功的情况下继续下一步。
-
torch.compile(..., backend="eager")
仅执行 TorchDynamo 的前向图捕获,然后使用 PyTorch 执行捕获的图。如果此操作失败,则表明 TorchDynamo 存在问题。 -
torch.compile(..., backend="aot_eager")
使用 TorchDynamo 捕获前向图,然后使用 AOTAutograd 追踪后向图而不进行额外的编译步骤。接着用 PyTorch eager 执行前向和后向图。如果执行失败,则说明 AOTAutograd 存在问题。 -
torch.compile(..., backend="inductor")
这行代码会使用 TorchDynamo 捕获前向图,然后通过 AOTAutograd 和 TorchInductor 编译器追踪反向图。如果这一步失败了,则说明存在与 TorchInductor 相关的问题。
为何编译速度慢?
-
Dynamo 编译– TorchDynamo 内置了一个统计功能,用于收集和显示每个编译阶段的时间。可以通过在执行
torch._dynamo
之后调用torch._dynamo.utils.compile_times()
来访问这些统计数据。默认情况下,这会返回一个字符串形式的报告,其中包含按名称统计的每个 TorchDynamo 函数所花费的编译时间。 -
电感器编译– TorchInductor 内置了统计和跟踪功能,用于显示每个编译阶段所花费的时间、生成的代码、图可视化以及 IR 转储。可以通过执行
env TORCH_COMPILE_DEBUG=1 python repro.py
来启用此调试工具,其输出类似于这个示例。每个调试跟踪文件都可以通过torch._inductor.config.trace.*
进行启用或禁用。由于生成这些信息的成本较高,默认情况下会禁用性能分析和图示功能。更多示例请参见调试目录输出示例。 -
过度重新编译 当 TorchDynamo 编译一个函数或其部分时,它会对局部变量和全局变量做出某些假设以允许编译优化。这些假设被表达为在运行时检查特定值的保护条件。如果任何保护条件失败,Dynamo 将最多重新编译该函数(或其部分)
torch._dynamo.config.cache_size_limit
次。如果你的程序达到了缓存限制,你需要确定哪个保护条件失败了以及你的程序中的哪一部分触发了它。重新编译分析器 自动将 TorchDynamo 的缓存限制设置为 1,并在仅观察模式下运行你的程序以记录任何保护条件失败的原因。你应该确保至少以与遇到问题时相同的持续时间(迭代次数)来运行你的程序,分析器将在整个期间累积统计数据。
为何在生产环境中进行重新编译?
在某些情况下,程序预热后可能不希望出现意外的编译,例如在延迟敏感的应用程序中处理生产流量时。为此,TorchDynamo 提供了一种替代模式,在这种模式下使用先前编译的图,但不再生成新的图:
frozen_toy_example = dynamo.run(toy_example) frozen_toy_example(torch.randn(10), torch.randn(10))
你是如何加快我的代码运行速度的?
有三种主要方式可以加速 PyTorch 代码:
-
通过垂直融合进行内核融合,将连续的操作合并以减少不必要的读取和写入。例如,将两个连续的余弦运算合并后,可以只进行一次读取和一次写入,而不是两次读取和两次写入。水平融合的一个简单例子是批处理,在批处理中一个矩阵与多个示例同时相乘。更一般的情况是一组矩阵乘法操作被一起调度执行。
-
乱序执行:编译器的一种通用优化技术,通过提前分析图中的数据依赖关系,我们可以确定最佳的执行时机,并判断哪些缓冲区可以被重用。
-
自动工作分配:类似于乱序执行的概念,但通过将图中的节点与物理硬件或内存等资源进行匹配,我们可以设计出合适的调度方案。
以上是一些加速 PyTorch 代码的通用原则,但不同的后端会在优化方面做出不同的权衡。例如,Inductor 首先会尽可能地融合操作,然后才生成Triton内核。
Triton还因为自动内存合并、内存管理以及每个流式多处理器内部的调度,提供了性能加速,并且它专门设计用于处理分块计算。
然而,无论你使用哪种后端,最好都进行基准测试并亲自观察。你可以尝试使用 PyTorch 分析器,通过可视化检查生成的内核来理解具体发生了什么。
为什么我没有观察到加速效果?
图表中断
你使用 Dynamo 没有看到预期的加速效果,主要原因是过多的图断开。那么,什么是图断开呢?
def some_fun(x): ... torch.compile(some_fun)(x) ...
Torchdynamo 尝试将 some_fun()
中的所有 torch/tensor 操作编译成一个单一的 FX 图,但可能会失败,无法将所有操作都包含在一个图中。
一些图中断的原因对于TorchDynamo来说是无法克服的,例如调用除PyTorch以外的C扩展对TorchDynamo来说是不可见的。这些扩展可能在没有给TorchDynamo引入必要的保护措施的情况下执行任意操作,导致编译后的程序不安全地重用。
为了最大化性能,应该尽可能减少图表中的断点。
确定图中断的原因
要识别程序中所有图中断及其原因,可以使用 torch._dynamo.explain
工具。该工具会在提供的函数上运行TorchDynamo,并汇总遇到的图中断情况。以下是一个示例用法:
import torch import torch._dynamo as dynamo def toy_example(a, b): x = a / (torch.abs(a) + 1) print("woo") if b.sum() < 0: b = b * -1 return x * b explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10)) print(explanation) """ Graph Count: 3 Graph Break Count: 2 Op Count: 5 Break Reasons: Break Reason 1: Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False User Stack: <FrameSummary file foo.py, line 5 in toy_example> Break Reason 2: Reason: generic_jump TensorVariable() User Stack: <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5> Ops per Graph: ... Out Guards: ... """
要在遇到第一个图形中断时抛出错误,你可以通过设置 fullgraph=True
来禁用 Python 回退功能。如果你之前使用过基于导出的编译器,这一点应该很熟悉。
def toy_example(a, b): ... torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
为什么我修改代码后没有触发重新编译?
如果你通过设置 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py
启用了动态形状功能,那么当输入的形状发生变化时,你的代码不会重新编译。我们添加了对动态形状的支持,在这种情况下,如果形状的变化小于两倍,则可以避免重新编译。这在诸如计算机视觉中图像大小变化或自然语言处理中序列长度可变的情况下特别有用。在推理场景中,由于你从不同的客户端应用程序获取数据,通常无法提前确定批次的大小。
通常,TorchDynamo 努力避免不必要的重新编译。例如,如果 TorchDynamo 找到 3 个图而你的更改只修改了一个图,则只会重新编译该图。因此,另一个技巧是先预热模型(即编译一次),这样后续的编译会快得多。冷启动编译时间是我们显式跟踪的一个指标。
为什么会得到错误的结果?
通过设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4
,可以减少准确性问题。它采用类似 git bisect 的模型,并且完全重现可能类似于 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4
。我们需要这个的原因是下游编译器会生成代码(无论是 Triton 代码还是 C++ 后端代码),这些编译器在细微之处的数值可能有所不同,但会对训练稳定性产生重大影响。因此,准确性调试器对于我们检测代码生成或后端编译器中的错误非常有用。
如果你希望确保 torch 和 triton 中的随机数生成一致,可以将 torch._inductor.config.fallback_random = True
开启。
为什么会收到内存不足的错误?
Dynamo 仍然是一个 alpha 版本,因此可能还会遇到一些导致内存溢出(OOM)的问题。如果你遇到了内存溢出,请按以下顺序禁用相应的配置,并在 GitHub 上提交一个问题以便我们解决根本原因:1. 如果你在使用动态形状,请尝试将其关闭,默认情况下它们已经被禁用了:env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py
2. 默认情况下,CUDA 图形与 Triton 在 inductor 中启用,但移除它们可能会缓解一些内存溢出问题:torch._inductor.config.triton.cudagraphs = False
.
torch.func
是否可以与 torch.compile
一起使用(特别是在 grad 和 vmap 变换中)?
可以将使用了torch.compile
的函数转换为torch.func
:
import torch @torch.compile def f(x): return torch.sin(x) def g(x): return torch.grad(f)(x) x = torch.randn(2, 3) g(x)
在使用torch.compile
处理的函数内部调用torch.func
变换
用torch.compile
编译torch.func.grad
import torch def wrapper_fn(x): return torch.func.grad(lambda x: x.sin().sum())(x) x = torch.randn(3, 3, 3) grad_x = torch.compile(wrapper_fn)(x)
用torch.compile
编译torch.vmap
import torch def my_fn(x): return torch.vmap(lambda x: x.sum(1))(x) x = torch.randn(3, 3, 3) output = torch.compile(my_fn)(x)
编译不受支持的函数(作为逃生机制)
对于其他变换,可以使用 torch._dynamo.allow_in_graph
作为替代方案
allow_in_graph
是一个逃生机制。如果你的代码与 torch.compile
不兼容,而 torch.compile
会分析 Python 字节码,但你认为通过符号跟踪方法(如 jax.jit
)可以解决问题,则使用 allow_in_graph
。
通过在函数上使用allow_in_graph
注解,你必须确保你的代码满足以下要求:
-
你的函数中所有的输出仅取决于输入,不依赖于任何被捕获的张量。
-
你的函数是纯函数式的,即不修改任何状态。这一要求可以适当放宽;我们实际支持的是那些从外部看起来像纯函数的函数:它们可以包含就地执行的 PyTorch 操作,但不能修改全局状态或函数输入。
-
你的函数不会因数据而引发错误。
import torch @torch.compile def f(x): return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x) x = torch.randn(2, 3) f(x)
一个常见的陷阱是在调用 nn.Module
的函数上使用 allow_in_graph
进行注解。这是因为输出现在依赖于 nn.Module
的参数。为了使这能正常工作,应使用 torch.func.functional_call
来提取模块状态。
NumPy 是否支持 torch.compile
?
从版本 2.1 开始,torch.compile
能够理解运行在 NumPy 数组上的原生 NumPy 程序,以及那些通过 x.numpy()
、torch.from_numpy
和相关函数在 PyTorch 和 NumPy 之间进行转换的混合程序。
torch.compile 支持哪些 NumPy 功能?
在torch.compile
中使用的NumPy遵循NumPy 2.0的预发布版本。
通常,torch.compile
可以追踪大多数的 NumPy 构造。当无法追踪时,它会退回到即时执行模式,并让 NumPy 执行那段代码。即使如此,仍然有一些特性使得 torch.compile
的行为与 NumPy 稍微不同:
-
NumPy 标量:我们将它们建模为0维数组。例如,
np.float32(3)
在torch.compile
下返回一个 0 维数组。为了避免图中断,最好使用这个 0 维数组。如果代码因此出现问题,可以通过将 NumPy 标量转换为相应的 Python 标量类型(如bool/int/float
)来解决。 -
负步长:使用
np.flip
以及带有负步长的切片操作会返回一个新的副本。 -
类型提升:NumPy 的类型提升规则在 NumPy 2.0 中将发生变化。新的规则可以在NEP 50中找到。
torch.compile
实现了 NEP 50,而不是即将被弃用的当前规则。 -
{tril,triu}_indices_from/{tril,triu}_indices
返回数组,而不是数组的元组。
对于某些我们不支持追踪的功能,我们将优雅地回退到使用 NumPy 进行执行。
-
如日期时间、字符串、字符、void、结构化数据类型和 recarrays 等非数值数据类型。
-
长整型数据类型
np.float128/np.complex256
以及一些无符号数据类型np.uint16/np.uint32/np.uint64
。 -
ndarray
子类。 -
带掩码的数组。
-
类似于
axes=[(n,k),(k,m)->(n,m)]
的esoteric ufunc机制以及ufunc方法(例如np.add.reduce
)。 -
对
complex64/complex128
数组进行排序。 -
NumPy 的
np.poly1d
和np.polynomial
。 -
当函数返回两个或更多的值时,可以使用位置参数
out1, out2
(或者使用out=tuple
也是可以的)。 -
__array_function__
,__array_interface__
和__array_wrap__
。 -
ndarray.ctypes
属性。
可以使用torch.compile
来编译NumPy代码吗?
当然可以!torch.compile
可以直接理解并处理 NumPy 代码,就像它是 PyTorch 代码一样。为此,只需用 torch.compile
装饰器来包裹你的 NumPy 代码。
import torch import numpy as np @torch.compile def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) X = np.random.randn(1024, 64) Y = np.random.randn(1024, 64) Z = numpy_fn(X, Y) assert isinstance(Z, np.ndarray)
在设置环境变量 TORCH_LOGS=output_code
后运行此示例,我们可以看到 torch.compile
能够将乘法和求和操作融合为一个 C++ 内核,并且能够使用 OpenMP 技术并行执行这些操作(而原生 NumPy 是单线程的)。这可以使你的 NumPy 代码加速多达 n
倍,其中 n
表示你处理器的核心数。
这种方式跟踪 NumPy 代码,还支持在编译代码中出现的图中断。
是否可以在CUDA上使用NumPy代码,并通过torch.compile
计算梯度?
当然可以!你可以在一个 torch.device("cuda")
上下文中运行你的代码。下面是一个示例:
import torch import numpy as np @torch.compile def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) X = np.random.randn(1024, 64) Y = np.random.randn(1024, 64) with torch.device("cuda"): Z = numpy_fn(X, Y) assert isinstance(Z, np.ndarray)
在这个例子中,numpy_fn
将在 CUDA 上执行。为了实现这一点,torch.compile
会自动将 X
和 Y
从 CPU 移动到 CUDA,并将结果 Z
再次移动回 CPU。如果我们在这个程序运行中多次执行此函数,我们可能希望避免这些昂贵的内存复制操作。为此,我们可以调整 numpy_fn
使其接受和返回 cuda 张量,并使用 torch.compiler.wrap_numpy
来实现这一点:
@torch.compile(fullgraph=True) @torch.compiler.wrap_numpy def numpy_fn(X, Y): return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) X = torch.randn(1024, 64, device="cuda") Y = torch.randn(1024, 64, device="cuda") Z = numpy_fn(X, Y) assert isinstance(Z, torch.Tensor) assert Z.device.type == "cuda"
在这里,我们显式地在CUDA内存中创建张量,并将它们传递给函数,在该函数中所有计算都在CUDA设备上完成。wrap_numpy
负责标记任何 torch.Tensor
输入为具有 np.ndarray
语义的输入,这是在 torch.compile
层级上进行的。在编译器内部对张量进行标记是一个非常便宜的操作,因此在运行时不会发生数据复制或移动。
使用这个装饰器,我们可以对 NumPy 代码进行求导。
@torch.compile(fullgraph=True) @torch.compiler.wrap_numpy def numpy_fn(X, Y): return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))) X = torch.randn(1024, 64, device="cuda", requires_grad=True) Y = torch.randn(1024, 64, device="cuda") Z = numpy_fn(X, Y) assert isinstance(Z, torch.Tensor) Z.backward() # X.grad now holds the gradient of the computation print(X.grad)
我们一直在使用fullgraph=True
,因为在当前情况下图中断会导致问题。当发生图中断时,我们需要将 NumPy 数组进行实体化。由于 NumPy 数组没有device
或requires_grad
的概念,在图中断期间这些信息会丢失。
我们无法通过图中断来传播梯度,因为图中断代码可能执行任意代码,这些代码不知道如何进行求导。另一方面,在 CUDA 执行的情况下,我们可以像在第一个示例中那样使用 torch.device("cuda")
上下文管理器来解决这个问题:
@torch.compile @torch.compiler.wrap_numpy def numpy_fn(X, Y): prod = X[:, :, None] * Y[:, None, :] print("oops, a graph break!") return np.sum(prod, axis=(-2, -1)) X = torch.randn(1024, 64, device="cuda") Y = torch.randn(1024, 64, device="cuda") with torch.device("cuda"): Z = numpy_fn(X, Y) assert isinstance(Z, torch.Tensor) assert Z.device.type == "cuda"
在图中断期间,中间张量仍然需要移动到CPU上。但在恢复追踪后,剩余的图部分仍在CUDA上进行追踪。由于这种CUDA与CPU之间的频繁切换,在NumPy上下文中,图中断的成本较高,应尽量避免。不过,至少它们允许通过复杂的代码段进行追踪。
如何在 torch.compile
下调试 NumPy 代码?
调试即时编译的代码颇具挑战性,考虑到现代编译器的复杂性和它们引发的令人难以处理的错误。如何在 torch.compile 中诊断运行时错误的教程提供了一些应对这一任务的小技巧。
如果上述方法仍无法确定问题的来源,我们还可以使用一些特定于NumPy的工具进行进一步排查。通过禁用NumPy函数中的跟踪功能,我们可以判断错误是否完全由PyTorch代码引起。
from torch._dynamo import config config.trace_numpy = False
如果错误出现在被追踪的NumPy代码中,我们可以使用PyTorch作为后端(不使用torch.compile
),通过导入import torch._numpy as np
来急切地执行NumPy代码。这仅用于调试目的,绝不是PyTorch API的替代品,因为它的性能要差得多,并且作为私有API,可能会在没有通知的情况下更改。无论如何,torch._numpy
是基于PyTorch实现的NumPy的Python版本,它被torch.compile
内部使用来将NumPy代码转换为PyTorch代码。它的读取和修改都非常容易,因此如果你发现其中有任何错误,请随时提交修复该问题的PR或直接打开一个issue。
如果导入 torch._numpy as np
后程序可以正常运行,那么问题可能出在 TorchDynamo 中。如果是这种情况,请提交一个包含最小复现代码的问题。
我使用 torch.compile
编译了一些 NumPy 代码,但没有看到任何速度提升。
最佳起点是参考这个教程,它提供了关于如何调试此类 torch.compile 问题的一般性建议。
一些图中断可能是由于使用了不受支持的功能导致的。参见torch.compile 支持哪些 NumPy 功能?。通常需要注意的是,某些广泛使用的 NumPy 功能与编译器不兼容。例如,就地修改使得在编译器中进行推理变得困难,并且性能往往不如非就地操作好。out=
参数同样需要避免使用。建议使用非就地操作,并让 torch.compile
优化内存使用。此外,数据依赖的操作(如通过布尔掩码进行的索引)和数据依赖的控制流(例如 if
或 while
构造)也需要特别注意。
应该使用哪个 API 进行细粒度跟踪?
在某些情况下,你可能需要将代码中的某些小部分从 torch.compile 编译中排除。本节提供了一些解决方案,更多相关信息可以在 TorchDynamo APIs for fine-grained tracing 中找到。
如何在函数上设置断点?
仅仅在函数上设置断点是不够的,无法充分表达你希望 PyTorch 做什么。你需要更详细地描述你的使用场景。以下是一些常见的使用场景供你参考:
-
如果你想在这项功能框架及其所有递归调用中禁用编译,可以使用
torch._dynamo.disable
。 -
如果你想让特定的操作符(如
fbgemm
)使用即时执行模式,可以使用torch._dynamo.disallow_in_graph
。
一些不太常见的用例包括:
-
如果你想在函数帧中禁用TorchDynamo,但对递归调用的帧重新启用它,请使用
torch._dynamo.disable(recursive=False)
。 -
如果你想防止某个函数帧的内联,可以在该函数的开头使用
torch._dynamo.graph_break
。
什么是 torch._dynamo.disable
和 torch._dynamo.disallow_in_graph
之间的区别
Disallow-in-graph 在操作符级别起作用,更具体地说,是指在 TorchDynamo 提取的图中可见的操作符。
Disable 作用于函数级别,决定 TorchDynamo 是否应分析该函数。
什么是 torch._dynamo.disable
和 torch._dynamo_skip
之间的区别
注意
torch._dynamo_skip
已被弃用。
你很可能需要使用 torch._dynamo.disable
。但在极少数情况下,你可能需要更细粒度的控制。假设你想只在 a_fn
函数中禁用跟踪,但希望继续在 aa_fn
和 ab_fn
中进行跟踪。下面的图片演示了这种情况:

在这种情况下,你可以使用 torch._dynamo.disable(recursive=False)
。在之前的版本中,这项功能是由 torch._dynamo.skip
提供的。现在可以通过设置 torch._dynamo.disable
中的 recursive
标志来实现。