编译自动微分:为 torch.compile
捕获更大的反向图
作者: Simon Fan
你将学到什么
-
编译后的自动求导如何与
torch.compile
交互 -
如何使用编译后的自动求导 API
-
如何使用
TORCH_LOGS
检查日志
前提条件
-
PyTorch 2.4
-
阅读 PyTorch 2.x 入门指南 中的 TorchDynamo 和 AOTAutograd 部分
概述
编译自动求导(Compiled Autograd)是 PyTorch 2.4 中引入的一个 torch.compile
扩展,它能够捕获更大的反向计算图。
虽然 torch.compile
确实会捕获反向计算图,但它是部分捕获的。AOTAutograd 组件会提前捕获反向计算图,但存在一些限制:
-
前向传播中的图断裂会导致反向传播中的图断裂
-
反向钩子未被捕获
编译后的自动梯度(Compiled Autograd)通过直接与自动梯度引擎集成,解决了这些限制,使其能够在运行时捕获完整的反向计算图。具有这两个特性的模型可以尝试使用编译后的自动梯度,并有可能观察到更好的性能。
然而,编译后的自动梯度也引入了自身的限制:
-
在反向传播开始时增加了缓存查找的运行时开销
-
由于捕获范围更大,在 dynamo 中更容易触发重编译和图中断
编译式自动求导功能正在积极开发中,目前尚不完全兼容所有现有的 PyTorch 功能。有关特定功能的最新状态,请参考 编译式自动求导功能主页。
配置
在本教程中,我们将基于这个简单的神经网络模型来进行示例。该模型接收一个10维的输入向量,通过一个线性层进行处理,并输出另一个10维的向量。
importtorch
classModel(torch.nn.Module):
def__init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
defforward(self, x):
return self.linear(x)
基本用法
在调用 torch.compile
API 之前,请确保将 torch._dynamo.config.compiled_autograd
设置为 True
:
model = Model()
x = torch.randn(10)
torch._dynamo.config.compiled_autograd = True
@torch.compile
deftrain(model, x):
loss = model(x).sum()
loss.backward()
train(model, x)
在上面的代码中,我们创建了一个 Model
类的实例,并通过 torch.randn(10)
生成了一个随机的 10 维张量 x
。我们定义了训练循环函数 train
,并用 @torch.compile
装饰它以优化其执行。当调用 train(model, x)
时:
-
Python 解释器调用 Dynamo,因为该调用被
@torch.compile
装饰。 -
Dynamo 拦截 Python 字节码,模拟其执行并将操作记录到图中。
-
AOTDispatcher
禁用钩子并调用 autograd 引擎来计算model.linear.weight
和model.linear.bias
的梯度,并将操作记录到图中。使用torch.autograd.Function
,AOTDispatcher
重写了train
的前向和反向实现。 -
Inductor 生成一个函数,对应于
AOTDispatcher
前向和反向的优化实现。 -
Dynamo 设置优化后的函数供 Python 解释器接下来执行。
-
Python 解释器执行优化后的函数,该函数执行
loss = model(x).sum()
。 -
Python 解释器执行
loss.backward()
,调用 autograd 引擎,由于我们设置了torch._dynamo.config.compiled_autograd = True
,因此会路由到 Compiled Autograd 引擎。 -
Compiled Autograd 计算
model.linear.weight
和model.linear.bias
的梯度,并将操作记录到图中,包括它遇到的任何钩子。在此过程中,它将记录之前由AOTDispatcher
重写的反向实现。Compiled Autograd 然后生成一个新函数,该函数对应于loss.backward()
的完全追踪实现,并在推理模式下使用torch.compile
执行它。 -
相同的步骤递归地应用于 Compiled Autograd 图,但这次
AOTDispatcher
不需要对图进行分区。
检查已编译的自动求导日志
使用 TORCH_LOGS
环境变量运行脚本:
-
仅打印编译后的 autograd 图,请使用
TORCH_LOGS="compiled_autograd" python example.py
-
要打印包含更多张量元数据和重新编译原因的图(但会牺牲性能),请使用
TORCH_LOGS="compiled_autograd_verbose" python example.py
重新运行上面的代码片段,编译后的自动微分图现在应该会被记录到 stderr
中。某些图节点会带有以 aot0_
为前缀的名称,这些节点对应于之前在 AOTAutograd 反向图 0 中提前编译的节点。例如,aot0_view_2
对应于 id=0 的 AOT 反向图中的 view_2
。
在下图中,红色框包围的是 torch.compile
在没有启用 Compiled Autograd 的情况下捕获的 AOT 反向图。
这是我们将调用
torch.compile
的图,而非优化后的图。Compiled Autograd 本质上会生成一些未优化的 Python 代码来表示整个 C++ autograd 的执行过程。
使用不同的标志编译前向和后向传递
您可以为两次编译使用不同的编译器配置,例如,即使在前向过程中存在图中断,后向过程也可能是一个完整的图。
deftrain(model, x):
model = torch.compile(model)
loss = model(x).sum()
torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss.backward(), fullgraph=True)()
或者您可以使用上下文管理器,它将应用于其作用域内的所有自动求导调用。
deftrain(model, x):
model = torch.compile(model)
loss = model(x).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
loss.backward()
Compiled Autograd 解决了 AOTAutograd 的某些限制
- 前向传播中的图中断不再必然导致反向传播中的图中断:
@torch.compile(backend="aot_eager")
deffn(x):
# 1st graph
temp = x + 10
torch._dynamo.graph_break()
# 2nd graph
temp = temp + 10
torch._dynamo.graph_break()
# 3rd graph
return temp.sum()
x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)
# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()
# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)
在第一个 torch.compile
案例中,由于编译函数 fn
中存在 2 个图中断,我们看到了 3 个反向图被生成。而在第二个使用了编译自动求导的 torch.compile
案例中,尽管存在图中断,我们仍看到了一个完整的反向图被追踪生成。
在追踪由 Compiled Autograd 捕获的反向钩子时,Dynamo 仍有可能出现图中断的情况。
- 现在可以捕获反向钩子
@torch.compile(backend="aot_eager")
deffn(x):
return x.sum()
x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
图中应该有一个 call_hook
节点,Dynamo 稍后会将其内联为以下内容:
Compiled Autograd 的常见重新编译原因
- 由于损失值的自动求导结构发生变化:
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
loss = op(x, x).sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,我们在每次迭代中调用不同的运算符,导致 loss
每次跟踪不同的 autograd 历史记录。您应该会看到一些重新编译的消息:由于新的 autograd 节点导致缓存未命中。
- 由于张量形状的变化:
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
x = torch.randn(i, i, requires_grad=True)
loss = x.sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,x
的形状发生了变化,编译后的 autograd 会在第一次变化后将 x
标记为动态形状张量。您应该会看到重新编译的消息:由于形状变化导致的缓存未命中。
结论
在本教程中,我们概述了torch.compile
的高级生态系统及其编译自动求导功能,介绍了编译自动求导的基础知识,以及一些常见的重新编译原因。敬请期待我们在dev-discuss上的深入探讨。