自定义后端

概述

torch.compile 提供了一种简便的方式来让用户定义自定义后端。

一个后端函数的契约是 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable

后端函数可以被 TorchDynamo 调用,它是 torch.compile 的图跟踪组件,在对 FX 图进行跟踪之后调用。这些函数需要返回一个与所跟踪的 FX 图等价的编译后的函数。返回的对象应该遵循原始 torch.fx.GraphModuleforward 函数的契约: (*args: torch.Tensor) -> List[torch.Tensor]

为了使 TorchDynamo 能调用你的后端函数,请在调用 torch.compile 时将该函数作为 backend 参数传递。例如:

import torch

def my_custom_backend(gm, example_inputs):
    return gm.forward

def f(...):
    ...

f_opt = torch.compile(f, backend=my_custom_backend)

@torch.compile(backend=my_custom_backend)
def g(...):
    ...

请参见下方的更多示例。

注册自定义后端

你可以使用 register_backend 装饰器来注册你的后端,例如:

from torch._dynamo import register_backend

@register_backend
def my_compiler(gm, example_inputs):
    ...

除了register_backend装饰器之外,如果你的后端位于另一个Python包中,还可以通过该包的入口点来注册后端。这种方式允许一个包为另一个包注册插件。

提示

你可以查阅Python 包装文档,了解更多关于entry_points的内容。

要通过 entry_points 注册你的后端,你可以在包的 setup.py 文件中将后端函数添加到 torch_dynamo_backends 入口点组,如下所示:

...
setup(
    ...
    'torch_dynamo_backends': [
        'my_compiler = your_module.submodule:my_compiler',
    ]
    ...
)

请将my_compiler替换为你的后端名称,并在等号之后的部分替换为你后端函数的模块和函数名。安装包后,入口点会被添加到你的Python环境中。当你调用torch.compile(model, backend="my_compiler")时,PyTorch会首先查找通过register_backend注册的名为my_compiler的后端。如果没有找到,则会在所有通过entry_points注册的后端中继续查找。

注册有两大用途:

  • 你可以将包含后端函数名称的字符串传递给torch.compile,而不需要传递函数本身,例如:torch.compile(model, backend="my_compiler")

  • 它在使用minifier时是必需的。任何由 minifier 生成的代码都必须通过一个 import 语句来调用你注册后端函数的代码。

AOTAutograd 之后的自定义后端

可以为AOTAutograd定义自定义后端(而非TorchDynamo),这样做主要有两个原因:

  • 用户可以定义支持模型训练的后端,因为AOTAutograd可以为编译生成反向图。

  • AOTAutograd 生成由 核心 Aten 操作 组成的 FX 图。因此,自定义后端只需支持核心 Aten 操作集,这比整个 torch/Aten 操作集要小得多。

torch._dynamo.backends.common.aot_autograd包装你的后端,并像之前一样使用torch.compilebackend关键字参数。由aot_autograd包装的后端函数应遵循之前的约定。

后端函数通过 aot_autogradfw_compiler(前向编译器)或 bw_compiler(反向编译器)关键字参数传递。如果没有指定 bw_compiler,则默认使用 fw_compiler 作为反向编译函数。

需要注意的是,AOTAutograd要求后端返回的编译函数必须是“带包装的”。可以使用functorch.compile.make_boxed_func来实现这一点。

例如:

from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

def my_compiler(gm, example_inputs):
    return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler

model_opt = torch.compile(model, backend=my_backend)

例子

调试后端

如果你想更好地理解编译过程,可以创建一个自定义编译器(本节称为后端),它会漂亮地打印出从Dynamo字节码分析中提取的fx GraphModule,并返回一个forward() 可调用对象。

例如:

from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torch.compile(backend=my_compiler)
def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b
fn(torch.randn(10), torch.randn(10))

运行上述示例会生成以下输出:

my_compiler() called with FX graph:
opcode         name    target                                                  args        kwargs
-------------  ------  ------------------------------------------------------  ----------  --------
placeholder    x       x                                                       ()          {}
placeholder    y       y                                                       ()          {}
call_function  cos     <built-in method cos of type object at 0x7f1a894649a8>  (x,)        {}
call_function  sin     <built-in method sin of type object at 0x7f1a894649a8>  (y,)        {}
call_function  add     <built-in function add>                                 (cos, sin)  {}
output         output  output                                                  ((add,),)   {}

这同样适用于torch.nn.Module,如下所示:

from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        return self.relu(torch.cos(x))
mod = MockModule()
optimized_mod = torch.compile(mod, backend=my_compiler)
optimized_mod(torch.randn(10))

让我们再看一个关于控制流的例子:

from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torch.compile(backend=my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

运行此示例会生成以下输出:

my_compiler() called with FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f8d259298a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}

my_compiler() called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    b       b                        ()           {}
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (b, -1)      {}
call_function  mul_1   <built-in function mul>  (x, mul)     {}
output         output  output                   ((mul_1,),)  {}

my_compiler() called with FX graph:
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    b       b                        ()         {}
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, b)     {}
output         output  output                   ((mul,),)  {}

The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.

快速后台

集成一个提供更好性能的自定义后端非常简单,我们将使用 optimize_for_inference 来整合一个真实的后端。

def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    scripted = torch.jit.script(gm)
    return torch.jit.optimize_for_inference(scripted)

然后你应该能够使用以下方法优化任何现有代码:

@torch.compile(backend=optimize_for_inference_compiler)
def code_to_accelerate():
    ...

可组合后端

TorchDynamo 包含多个后端,可以通过 torch._dynamo.list_backends() 列出它们。你可以使用以下代码将这些后端结合起来:

from torch._dynamo import lookup_backend
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    try:
        trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
        if trt_compiled is not None:
            return trt_compiled
    except Exception:
        pass
    # first backend failed, try something else...
    try:
        inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
        if inductor_compiled is not None:
            return inductor_compiled
    except Exception:
        pass
    return gm.forward
本页目录