在ATen IR上进行图变换的编写

阶段通过

由于ATen IR处于FX Graph和GraphModule的层级上,因此可以很容易地将为FX Graph编写的任何转换应用到ATen IR中。如果你熟悉编写FX图变换,那么这个过程对你来说会很熟悉。

编写转换的最直接方法是通过遍历给定的图形并直接操作其中的节点。

例如,假设我们要将 torch.ops.aten.add.Tensor() 替换为 torch.ops.aten.mul.Tensor()

import torch

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor

我们也可以通过Graph 文档中提供的FX 工具函数来删除和添加新的节点。例如,如果我们想在add 调用之后插入一个 torch.ops.aten.relu.default():

import torch

def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use the `new_relu_node`
                node.replace_all_uses_with(new_relu_node)

通常来说,变换可以大致归为几类:

轴A:1. 创建一对多映射(例如分解) 2. 创建多对一映射(例如融合)

轴B:1. 正向迭代(如形状传播)2. 反向迭代(如消除死代码)

轴 C:1. 依赖本地节点信息(如输出变体转换)2. 依赖全局图信息(如内存规划)

我们对这些用例使用频率的预测如下:1. A.1、B.1、C.1 2. A.2 3. B.2、C.2

虽然我们可以通过直接操作图形来完成所有的图变换,但也为一级和二级用例提供了一些辅助工具,以便更轻松地使用。

Transformer

对于一级用例(例如创建一对一到多的映射、执行正向迭代以及查看本地节点信息),我们可以使用Transformer 类来执行每个节点并重新创建一个图,同时应用指定的转换。

一对一传球

例如,一对一映射的情况下,如果我们想用操作 B 替换操作 A,可以运行 GraphModule,在每次遇到操作 A 时返回操作 B。

例如:

class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()

调用 super().call_function(target, args, kwargs, meta) 会创建一个 call_function FX 节点,并返回使用给定参数运行操作的结果。

一对多传递

如果我们想要进行一對一到多的映射,比如將操作 A 替換為另外兩個操作 B 和 C,那麼我們會調用兩次 super().call_function 來創建兩個 FX 节点,一個使用操作 B,另一個使用操作 C,並返回操作 C 的運行結果。

例如:

class ReplaceAddWithMulSub(torch.fx.Transformer):
"""
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)

        x, y = args

        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()

一对一传球

如果我们想删除一个操作,可以直接返回传入函数的值。

class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs, meta)

        assert len(args) == 1
        return args[0]

transformed_graph_module = RemoveDetachPass(graph_module).transform()

利用本地信息

利用本地节点信息的一个例子是,如果我们想将图中的所有标量转换为张量,可以运行给定的 fx.GraphModule,并将每个包含标量的参数转换为张量。这可能看起来像:

def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)
    return tuple(args), kwargs

class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        breakpoint()
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)

transformed_graph_module = ScalarToTensorPass(graph_module).transform()

子图重写器

为了创建多对一映射,我们可以使用FX的子图重写器。给定一个pattern,它会生成与该模式匹配的操作符子图,并将每个匹配的子图替换为replacement

注意:

This is an inplace operation.

patternreplacement 输入必须是可调用的函数或包含相同 ATen 操作符的 GraphModules,以便子图重写器可以在图中找到正确的模式。在匹配过程中,传递给 pattern/replacement 可调用输入的参数将被视为通配符。

示例:

from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)

子图重写器返回一个 ReplacedPatterns 列表:

@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]

注意:

The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.

密码管理器

`PassManager`<https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py> 是一个用于在给定的图模块上运行多个转换(passes)的类。当我们初始化 PassManager 实例时,会传入一系列要执行的转换,并设置一些标志。为了在一个图模块上运行这些转换集合,可以直接将该图模块传递给 PassManager 实例。

示例:

from torch.fx.passes.infra.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)

为了添加一组通用的检查并在每次遍历后执行,我们可以调用函数set_checks(check: Callable),该函数接受一个可调用函数作为参数。如果设置了run_checks_after_each_pass标志,则在每次对图模块进行遍历时都会调用check

示例:

pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass

分区器

我们可以使用几种常见的基于FX图的分区器来划分图形。

子图匹配器

为了在图中找到匹配特定模式的子图,可以使用FX的`SubgraphMatcher` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>。

类属性:

  • pattern (Graph): 目标匹配模式。图中的占位符节点在匹配时会被视为通配符。

  • match_output (bool): 如果为 True,模式图中的输出节点会被视为目标模式的一部分;如果为 False,则在匹配过程中忽略输出节点。

  • match_placeholder (bool): 如果为 True,模式图中的占位符节点会被视为目标模式的一部分;如果为 False,则占位符节点会作为通配符使用。

  • remove_overlapping_matches (bool): 如果设置为 True,在出现匹配项重叠时,只会返回第一个匹配项。

  • ignore_literals (bool): 如果为 True,则不检查字面量是否相等,并将其视为通配符处理。

示例:

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = torch.export(LargeModel(), inputs).graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = torch.export(PatternModel(), inputs).graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)

match 函数返回一个包含 InternalMatch 对象的列表。

@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

基于能力的分区器

为了找到支持特定不变量的最大子图节点,我们可以使用 FX 的 `CapabilityBasedPartitioner` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>。

类属性

  • graph_module (torch.fx.GraphModule): 需要进行分区的图模块。

  • operator_support (OperatorSupportBase): 用于判断图中节点在分区中是否被支持的对象。

  • allows_single_node_partition (bool): 如果为 True,允许创建单节点分区。

  • non_compute_ops (Optional[Sequence[str]]): 一组被视为“非计算”操作(例如 torch.ops.aten.view_operator.getitem),这样分区器就不会生成仅包含这些非计算操作的图。

  • allowed_single_node_partition_ops (Optional[Sequence[str]]): 允许在单个节点分区中执行的操作集合。

The `OperatorSupportBase` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1> class is used by the partitioner to determine whether a specific node in the graph belongs to a particular partition. This determination is made by overriding the `is_node_supported` function. You can chain multiple `OperatorSupportBase` instances using `chain` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150> (which returns False if any `OperatorSupportBase` instance returns False) and `any_chain` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164> (which returns True if any `OperatorSupportBase` instance returns True).

示例:

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)
本页目录