自定义后端
概述
torch.compile
提供了一种简便的方式来让用户定义自定义后端。
一个后端函数的契约是 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
。
后端函数可以被 TorchDynamo 调用,它是 torch.compile
的图跟踪组件,在对 FX 图进行跟踪之后调用。这些函数需要返回一个与所跟踪的 FX 图等价的编译后的函数。返回的对象应该遵循原始 torch.fx.GraphModule
中 forward
函数的契约: (*args: torch.Tensor) -> List[torch.Tensor]
。
为了使 TorchDynamo 能调用你的后端函数,请在调用 torch.compile
时将该函数作为 backend
参数传递。例如:
import torch def my_custom_backend(gm, example_inputs): return gm.forward def f(...): ... f_opt = torch.compile(f, backend=my_custom_backend) @torch.compile(backend=my_custom_backend) def g(...): ...
请参见下方的更多示例。
注册自定义后端
你可以使用 register_backend
装饰器来注册你的后端,例如:
from torch._dynamo import register_backend @register_backend def my_compiler(gm, example_inputs): ...
除了register_backend
装饰器之外,如果你的后端位于另一个Python包中,还可以通过该包的入口点来注册后端。这种方式允许一个包为另一个包注册插件。
提示
你可以查阅Python 包装文档,了解更多关于entry_points
的内容。
要通过 entry_points
注册你的后端,你可以在包的 setup.py
文件中将后端函数添加到 torch_dynamo_backends
入口点组,如下所示:
... setup( ... 'torch_dynamo_backends': [ 'my_compiler = your_module.submodule:my_compiler', ] ... )
请将my_compiler
替换为你的后端名称,并在等号之后的部分替换为你后端函数的模块和函数名。安装包后,入口点会被添加到你的Python环境中。当你调用torch.compile(model, backend="my_compiler")
时,PyTorch会首先查找通过register_backend
注册的名为my_compiler
的后端。如果没有找到,则会在所有通过entry_points
注册的后端中继续查找。
注册有两大用途:
-
你可以将包含后端函数名称的字符串传递给
torch.compile
,而不需要传递函数本身,例如:torch.compile(model, backend="my_compiler")
。 -
它在使用minifier时是必需的。任何由 minifier 生成的代码都必须通过一个
import
语句来调用你注册后端函数的代码。
AOTAutograd 之后的自定义后端
可以为AOTAutograd定义自定义后端(而非TorchDynamo),这样做主要有两个原因:
-
用户可以定义支持模型训练的后端,因为AOTAutograd可以为编译生成反向图。
-
AOTAutograd 生成由 核心 Aten 操作 组成的 FX 图。因此,自定义后端只需支持核心 Aten 操作集,这比整个 torch/Aten 操作集要小得多。
用torch._dynamo.backends.common.aot_autograd
包装你的后端,并像之前一样使用torch.compile
和backend
关键字参数。由aot_autograd
包装的后端函数应遵循之前的约定。
后端函数通过 aot_autograd
的 fw_compiler
(前向编译器)或 bw_compiler
(反向编译器)关键字参数传递。如果没有指定 bw_compiler
,则默认使用 fw_compiler
作为反向编译函数。
需要注意的是,AOTAutograd要求后端返回的编译函数必须是“带包装的”。可以使用functorch.compile.make_boxed_func
来实现这一点。
例如:
from torch._dynamo.backends.common import aot_autograd from functorch.compile import make_boxed_func def my_compiler(gm, example_inputs): return make_boxed_func(gm.forward) my_backend = aot_autograd(fw_compiler=my_compiler) # bw_compiler=my_compiler model_opt = torch.compile(model, backend=my_backend)
例子
调试后端
如果你想更好地理解编译过程,可以创建一个自定义编译器(本节称为后端),它会漂亮地打印出从Dynamo字节码分析中提取的fx GraphModule
,并返回一个forward()
可调用对象。
例如:
from typing import List import torch def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() return gm.forward # return a python callable @torch.compile(backend=my_compiler) def fn(x, y): a = torch.cos(x) b = torch.sin(y) return a + b fn(torch.randn(10), torch.randn(10))
运行上述示例会生成以下输出:
my_compiler() called with FX graph: opcode name target args kwargs ------------- ------ ------------------------------------------------------ ---------- -------- placeholder x x () {} placeholder y y () {} call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {} call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {} call_function add <built-in function add> (cos, sin) {} output output output ((add,),) {}
这同样适用于torch.nn.Module
,如下所示:
from typing import List import torch def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() return gm.forward # return a python callable class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(torch.cos(x)) mod = MockModule() optimized_mod = torch.compile(mod, backend=my_compiler) optimized_mod(torch.randn(10))
让我们再看一个关于控制流的例子:
from typing import List import torch def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX graph:") gm.graph.print_tabular() return gm.forward # return a python callable @torch.compile(backend=my_compiler) def toy_example(a, b): x = a / (torch.abs(a) + 1) if b.sum() < 0: b = b * -1 return x * b for _ in range(100): toy_example(torch.randn(10), torch.randn(10))
运行此示例会生成以下输出:
my_compiler() called with FX graph: opcode name target args kwargs ------------- ------- ------------------------------------------------------ ---------------- -------- placeholder a a () {} placeholder b b () {} call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {} call_function add <built-in function add> (abs_1, 1) {} call_function truediv <built-in function truediv> (a, add) {} call_method sum_1 sum (b,) {} call_function lt <built-in function lt> (sum_1, 0) {} output output output ((truediv, lt),) {} my_compiler() called with FX graph: opcode name target args kwargs ------------- ------ ----------------------- ----------- -------- placeholder b b () {} placeholder x x () {} call_function mul <built-in function mul> (b, -1) {} call_function mul_1 <built-in function mul> (x, mul) {} output output output ((mul_1,),) {} my_compiler() called with FX graph: opcode name target args kwargs ------------- ------ ----------------------- --------- -------- placeholder b b () {} placeholder x x () {} call_function mul <built-in function mul> (x, b) {} output output output ((mul,),) {} The order of the last two graphs is nondeterministic depending on which one is encountered first by the just-in-time compiler.
快速后台
集成一个提供更好性能的自定义后端非常简单,我们将使用 optimize_for_inference 来整合一个真实的后端。
def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): scripted = torch.jit.script(gm) return torch.jit.optimize_for_inference(scripted)
然后你应该能够使用以下方法优化任何现有代码:
@torch.compile(backend=optimize_for_inference_compiler) def code_to_accelerate(): ...
可组合后端
TorchDynamo 包含多个后端,可以通过 torch._dynamo.list_backends()
列出它们。你可以使用以下代码将这些后端结合起来:
from torch._dynamo import lookup_backend def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): try: trt_compiled = lookup_backend("tensorrt")(gm, example_inputs) if trt_compiled is not None: return trt_compiled except Exception: pass # first backend failed, try something else... try: inductor_compiled = lookup_backend("inductor")(gm, example_inputs) if inductor_compiled is not None: return inductor_compiled except Exception: pass return gm.forward