torch.fx

概述

FX 是一个开发工具包,用于转换 nn.Module 实例。它包含三个主要组件:符号追踪器、中间表示和 Python 代码生成。以下是这些组件的示例演示:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号跟踪器 对 Python 代码执行“符号执行”。它将假值(称为代理 Proxies)传递给代码。对这些代理的操作会被记录下来。有关符号跟踪的更多信息可以在 symbolic_trace()Tracer 文档中找到。

中间表示(Intermediate Representation,简称IR)是符号跟踪期间记录的操作容器。它由一系列节点组成,这些节点代表函数输入、调用站点(包括函数、方法或torch.nn.Module 实例),以及返回值。关于IR的更多信息可以在 Graph 文档中找到。IR是应用变换的基础格式。

Python代码生成 是使FX成为Python到Python(或模块到模块)转换工具包的关键。对于每个图IR,我们可以创建与其语义匹配的有效Python代码。此功能封装在GraphModule中,它是持有Graph以及从该图生成的forward方法的torch.nn.Module实例。

这些组件组成的管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)共同构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号跟踪可以在隔离状态下用于捕获代码的一种形式以进行分析(而非转换)。代码生成可用于根据配置文件等程序化地生成模型。FX 具有许多用途!

几个示例转换可以在示例仓库中找到。

编写转换

什么是FX变换?本质上,它是一个这样的函数。

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

你的转换会接收一个torch.nn.Module,从中获取一个Graph,进行一些修改,并返回一个新的torch.nn.Module。你应该将你的FX转换返回的torch.nn.Module视为与常规的torch.nn.Module相同——你可以将其传递给另一个FX转换,可以传递给TorchScript,或者直接运行它。确保你的FX转换的输入和输出是一个torch.nn.Module将使其具有组合性。

注意

你也可以选择修改现有的GraphModule,而不需要创建新的,具体方法如下:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

请注意,你必须调用GraphModule.recompile(),以确保生成的forward()方法与修改后的Graph保持同步。

鉴于你传入了一个已经被追踪为 torch.nn.ModuleGraph,现在有两种主要的方法可以用来构建一个新的 Graph

图的基础简介

关于图的语义的完整处理可以在 Graph 文档中找到,但这里我们将介绍基础知识。一个 Graph 是一种数据结构,表示在 GraphModule 上的方法。所需的信息包括:

  • 这个方法的输入参数是什么?

  • 方法内执行了哪些操作?

  • 该方法的返回值是什么?

这三个概念都使用Node实例来表示。让我们通过一个简短的例子来看看具体含义:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

在这里,我们为了演示目的定义了一个名为MyModule的模块,对其进行实例化和符号跟踪,然后调用Graph.print_tabular()方法来打印显示该Graph节点的表格。

操作码

姓名

目标

参数

kwargs

占位符

x

x

()

{}

获取属性

线性权重

linear权重

()

{}

调用函数

添加_1

add(内置函数)

(x, 线性权重)

{}

调用模块

线性_1

线性

(add_1,)

{}

调用方法

ReLU_1

ReLU

(linear_1)

{}

调用函数

总述_1

<内置求和方法>

relu_1

{"dim": -1}

调用函数

topk_1

<内置方法 topk >

sum(1, 3)

{}

输出结果

输出结果

输出结果

(topk_1,)

{}

我们可以使用这些信息来回答之前提出的问题。
  • 方法的输入是什么?在FX中,方法输入通过特殊的placeholder节点指定。在这种情况下,我们有一个具有targetx的单个placeholder节点,这意味着我们有一个名为x的单一(非self)参数。

  • 方法内的操作有哪些?get_attrcall_functioncall_modulecall_method 节点代表了方法中的操作。所有这些节点的完整语义解释可以在Node 文档中找到。

  • 方法的返回值是什么?在Graph中,返回值由一个特殊的output节点指定。

既然我们现在掌握了代码在FX中的基本表示方法,就可以开始探索如何编辑一个Graph了。

图形操作

直接图 Manipulation

构建新的Graph的一种方法是直接操作现有的图。为此,我们可以获取符号跟踪得到的Graph并进行修改。例如,假设我们希望将torch.add()调用替换为torch.mul()调用。

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以执行更复杂的 Graph 重写操作,例如删除或添加节点。为了帮助这些转换,FX 提供了一些用于图变换的实用函数,这些函数可以在Graph文档中找到。下面提供了一个使用这些 API 添加 torch.relu() 调用的例子。

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于仅包含替换的简单转换,你还可以使用子图重写器

利用replace_pattern()进行子图重写

FX 还在直接图操作的基础上提供了另一层自动化。replace_pattern() API 实质上是一个用于编辑Graph的“查找/替换”工具。它允许你指定一个patternreplacement函数,并遍历这些函数,找到在pattern图中的操作组实例,并用replacement图的副本替换它们。这可以帮助大大简化繁琐的图操作代码,在转换变得更加复杂时,这种代码可能会变得难以管理。

代理/回溯

另一种操作 Graph 的方式是重用符号跟踪中使用的 Proxy 机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小的操作的转换器。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能是执行必要的图重写,在 F.relu 后插入比较和乘法操作,然后清理原始的 F.relu。然而,我们可以通过使用 Proxy 对象自动记录操作到 Graph 中来自动化这个过程。

要使用此方法,我们编写常规的 PyTorch 代码来插入所需的操作,并用 Proxy 对象作为参数调用该代码。这些 Proxy 对象会捕获在其上执行的操作并将它们追加到 Graph 中。

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式图操作外,使用 Proxy 还可以让你将重写规则指定为原生 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。需要注意的是,在调用 Proxy 时,我们还传递了一个指向底层变量图的 tracer。这样做是为了在图中的操作是 n 元(例如 add 是一个二元运算符)的情况下,Proxy 调用不会创建多个图追踪器实例,这可能会导致意外的运行时错误。我们特别推荐在这种情况下使用 Proxy,尤其是在不能安全地假设底层操作符为一元的情况下。

一个使用Proxy进行Graph操作的示例可以在这里找到。

解释器模式

FX 中一个有用的代码组织模式是遍历 Graph 中的所有 Node 并执行它们。这种模式可以用于多种用途,例如通过运行时分析流经图的值或使用 Proxy 重新跟踪来转换代码。例如,假设我们希望运行一个 GraphModule 并在运行时记录我们在节点上看到的 torch.Tensor 的形状和数据类型属性。这可能看起来像:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
"""
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

如你所见,一个完整的FX解释器并不复杂,但非常有用。为了方便使用这一模式,我们提供了Interpreter类,该类通过方法重写封装了上述逻辑,并允许覆盖解释器执行的某些方面。

除了执行操作之外,我们还可以通过将 Proxy 值传递给解释器来生成新的 Graph。同样地,我们提供了 Transformer 类来涵盖这种模式。Transformer 的行为类似于 Interpreter,但你不会调用 run 方法来获取 Module 的具体输出值,而是会调用Transformer.transform() 方法以返回一个新的 GraphModule,该模块经过了你安装的任何转换规则(作为重写的方法)。

解释器模式示例

调试

介绍

在编写转换的过程中,我们的代码往往不会完全正确。在这种情况下,我们需要进行一些调试工作。关键是倒着来:首先,检查调用生成模块的结果以验证其正确性或错误;然后,检查和调试生成的代码;最后,调试导致生成代码的整个转换过程。

如果您不熟悉调试器,请参见辅助部分中的“可用的调试器”。

转换创作中的常见陷阱

  • 非确定性的set迭代顺序。在Python中,set 数据类型是无序的。例如,使用 set 来存储像 Node 这样的对象集合可能会导致意外的非确定性结果。一个例子是在一组 Node 上进行迭代并将它们插入到 Graph 中。由于 set 数据类型是无序的,输出程序中的操作顺序将是不确定的,并且在不同程序调用之间可能会发生变化。推荐使用按插入顺序排列的 dict 数据类型作为替代方案,在Python 3.7(以及cPython 3.6)中它是有序的。可以通过将要去重的值存储在 dict 的键中,来等效地使用 dict 替代集合。

验证模块的准确性

由于大多数深度学习模块的输出是由浮点数 torch.Tensor 实例组成的,因此判断两个 torch.nn.Module 的结果是否相等并不像简单的相等性检查那样直接。为了说明这一点,我们来看一个例子:

import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试使用==运算符检查两个深度学习模型的值是否相等。然而,由于该运算符返回的是张量而不是布尔值,并且浮点数比较应考虑误差范围(或epsilon)以处理浮点运算的非交换性问题(详情请参阅此处),因此这种做法是不明确的。我们可以改用torch.allclose(),它会根据相对和绝对容差阈值进行近似比较:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中的第一个工具,用于检查转换后的模块是否按预期运行,并与参考实现进行比较。

调试生成的代码

由于FX在GraphModule上生成了forward()函数,使用传统的调试方法(如print语句或pdb)并不那么直接。幸运的是,我们有一些技术可以用来调试生成的代码。

使用 pdb

调用pdb来进入正在运行的程序。尽管表示Graph的代码不在任何源文件中,我们仍然可以在前向传递被调用时手动使用pdb进行调试。

import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用GraphModule中的to_folder函数

GraphModule.to_folder()GraphModule 中的一个方法,允许你将生成的 FX 代码导出到一个文件夹中。虽然通常复制前向传递代码就足够了(如在打印生成的代码中所示),但使用 to_folder 方法来检查模块和参数可能更方便。

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看并修改 foo/module.py 中的代码(例如添加 print 语句或使用 pdb),以便调试生成的代码。

调试变换

现在我们已经确定某个转换生成了错误的代码,是时候开始调试这个转换本身了。首先,我们将查看文档中的符号跟踪的限制部分。确认跟踪功能正常后,我们需要找出在GraphModule 转换过程中出现了什么问题。可能在编写转换中有快速解答,但如果找不到的话,有几种方法可以检查我们的追踪模块:

# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上述工具函数,我们可以对比应用变换前后的跟踪模块。有时候,简单的视觉比较就能帮助我们找到错误。如果还是无法确定问题所在,可以考虑使用像 pdb 这样的调试器作为下一步的解决方案。

参考上面的例子,考虑以下代码:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

"""
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

根据上述示例,假设调用print(traced)显示我们的转换中存在错误。我们希望通过调试器找出问题所在。启动一个pdb会话后,在transform_graph(traced)上设置断点,并按s键“步入”该调用,以查看转换过程中发生了什么。

我们也可以通过修改 print_tabular 方法来打印图中节点的不同属性,例如节点的 input_nodesusers 属性。

可用的调试工具

最常用的 Python 调试器是 pdb。你可以在命令行中输入python -m pdb FILENAME.py来以“调试模式”启动你的程序,其中FILENAME是你想要调试的文件名。之后,你可以使用 pdb调试命令逐步运行你的程序。通常在开始时设置一个断点(b LINE-NUMBER),然后调用c让程序运行到该断点处,这样可以避免你需要逐行执行代码(使用 sn)来到达你想要检查的代码部分。或者,你可以在你希望中断的地方之前写入import pdb; pdb.set_trace()。如果你添加了pdb.set_trace(),当你运行程序时它会自动以调试模式启动。(换句话说,你可以直接在命令行中输入 python FILENAME.py 而不是 python -m pdb FILENAME.py)。一旦你在调试模式下运行你的文件,你可以使用某些命令逐步执行代码并检查程序的内部状态。有许多关于pdb的优秀教程在线上,包括RealPython的“使用Pdb进行Python调试”

像 PyCharm 或 VSCode 这样的集成开发环境通常都内置了调试器。在你的 IDE 中,你可以选择两种方式来使用 pdb:a) 在 IDE 中打开一个终端窗口(例如,在 VSCode 中为 View → Terminal),或者 b) 使用 IDE 内置的图形化调试工具(该工具通常基于 pdb)。

符号跟踪的局限性

FX 使用一种称为 符号跟踪 (又称 符号执行)的系统,以可转换和分析的形式捕获程序语义。该系统通过执行程序(实际上是torch.nn.Module或函数)并记录操作来进行跟踪。在执行过程中,流经程序的数据不是实际数据,而是符号(在FX中称为 Proxy),因此它是符号化的。

虽然符号跟踪适用于大多数神经网络代码,但它也存在一些局限性。

动态控制流

符号跟踪的主要限制在于它当前不支持动态控制流,例如循环或if语句,这些语句中的条件可能会根据程序的输入值发生变化。

例如,我们来看看以下程序:

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

如果语句的条件依赖于 x.sum() 的值,而 x.sum() 又依赖于输入参数 x 的值。由于 x 可以变化(例如传递一个新的输入张量给被跟踪的函数),这种情况被称为动态控制流。追踪回溯会向上遍历你的代码,展示出这种情形发生的位置。

静态控制流

另一方面,所谓的静态控制流是受支持的。静态控制流指的是那些在不同调用中值不会改变的循环或if语句。通常,在 PyTorch 程序中,这种类型的控制流出现在根据超参数对模型架构做出决策的代码里。例如:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if 语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被视为一个超参数,并且不同实例的代码会根据该参数的不同值而有所不同(例如:MyModule 的不同实例)。这是一个由符号跟踪支持的有效模式。

许多动态控制流的实例实际上具有静态控制流的语义。通过移除对输入值的数据依赖(例如,将值移动到Module属性或在符号跟踪期间将具体值绑定到参数),可以使这些实例支持符号跟踪:

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

对于真正动态的控制流,包含此代码的程序部分可以被追踪为对方法(参见使用Tracer类自定义跟踪)或函数(参见wrap())的调用,而不是直接通过它们进行追踪。

torch函数

FX 使用 __torch_function__ 作为拦截调用的机制(参见技术概述以获取更多信息)。一些函数,如内置 Python 函数或 math 模块中的函数,并未被 __torch_function__ 支持,但我们仍然希望在符号跟踪中捕获它们。例如:

import torch
import torch.fx
from math import sqrt

def normalize(x):
"""
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

此错误表明内置函数 len 不被支持。我们可以使用 wrap() API,将这类函数在追踪信息中以直接调用的形式记录下来。

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用Tracer类进行自定义追踪

The Tracer 类是 symbolic_trace 的实现基础。可以通过继承和扩展 Tracer 类来自定义追踪行为,例如:

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶模块

叶模块是指在符号跟踪中以调用形式出现而非被追踪的模块。默认情况下,叶模块包括标准的 torch.nn 模块实例。例如:

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

可以通过覆盖Tracer.is_leaf_module()来自定义叶模块的集合。

杂项

  • 目前,张量构造函数(例如 torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor)不可追踪。

    • 确定性的构造函数(如 zerosones)可以使用,它们生成的值将作为常量嵌入到跟踪中。只有当这些构造函数的参数引用动态输入大小时,才会出现问题。在这种情况下,ones_likezeros_like 可能是更好的替代方案。

    • 非确定性构造函数(randrandn)会在跟踪中嵌入一个单一的随机值,这可能不是预期的行为。一种解决方法是将 torch.randn 包裹在一个 torch.fx.wrap 函数中,并调用该函数。

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 这个问题可能会在未来的一个版本中得到修复。

  • 类型注解

    • Python 3 风格的类型注解(例如 func(x: torch.Tensor, y: int) -> torch.Tensor)是受支持的,并且会被符号跟踪保留。

    • 当前不支持 Python 2 风格的注释类型标注,例如:# type: (torch.Tensor, int) -> torch.Tensor

    • 当前不支持在函数内部的局部名称上使用注解。

  • 关于training标志和子模块的相关注意事项

    • 当使用像torch.nn.functional.dropout这样的功能时,训练参数通常会以self.training的形式传递。在FX追踪过程中,这个参数很可能被固定为一个常量值。

    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
      dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
      return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 然而,当使用标准的nn.Dropout()子模块时,训练标志会被封装。由于保留了nn.Module对象模型,因此可以修改这个标志。

    class DropoutRepro2(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
      def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
      drop = self.drop(x);  x = None
      return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 由于这种差异,考虑将与training标志动态交互的模块标记为叶模块。

API参考

torch.fx.symbolic_trace(root, concrete_args=None)[源代码]

符号追踪API

给定一个 nn.Module 或函数实例 root,此函数将返回一个通过在跟踪 root 时记录的操作构建的 GraphModule

concrete_args 允许你部分地专门化你的函数,以移除控制流或数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于存在控制流,FX 通常无法进行追踪。但是,我们可以使用 concrete_args 来专门针对 b 的值进行追踪。

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False)  == 6

请注意,尽管你可以传入不同的b值,但这些值将会被忽略。

我们还可以使用concrete_args来移除函数中的数据结构处理。这会利用pytrees展开输入参数。为了避免过度专门化,对于不需要专门化的值应传递fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
参数
  • root (Union[torch.nn.Module, Callable]) – 需要被追踪并转换为图表示的模块或函数。

  • concrete_args (Optional[Dict[str, any]]) – 需要部分特化的输入参数

返回值

root记录的操作生成的模块。

返回类型

GraphModule

注意

此 API 的向后兼容性得到了保证。

torch.fx.wrap(fn_or_name)[源代码]

此函数可以在模块级别的作用域中调用,用于将 fn_or_name 注册为“叶子函数”。一个“叶子函数”在 FX 跟踪中会保留为 CallFunction 节点,而不会被进一步追踪。

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数还可以等效地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

一个被包装的函数可以视为一个“叶函数”,类似于“叶模块”的概念。也就是说,在FX跟踪中,这些函数会被保留在调用层面,而不进行内部追踪。

参数

fn_or_name (Union[str, Callable]) – 在调用时插入图中的函数或全局函数的名称

注意

此 API 的向后兼容性得到了保证。

torch.fx.GraphModule(*args, **kwargs)[源代码]

GraphModule 是由 fx.Graph 生成的一个 nn.Module。它拥有一个 graph 属性,以及从该 graph 生成的 codeforward 属性。

警告

graph被重新赋值时,codeforward将自动再生。然而,如果你在不重新赋值graph属性的情况下编辑了graph的内容,你需要调用recompile()来更新生成的代码。

注意

此 API 的向后兼容性得到了保证。

__init__(root, graph, class_name='GraphModule')[源代码]

创建一个GraphModule。

参数
  • root (Union[torch.nn.Module, Dict[str, Any]) – root 可以是一个 nn.Module 实例,也可以是将字符串映射到任意属性类型的字典。如果 root 是一个 Module,则图中的节点的 target 字段中引用的基于 Module 的对象(通过全限定名)将会从 root 的 Module 层次结构复制到 GraphModule 的模块层次结构。如果 root 是一个字典,则节点中的 target 中的全限定名将直接在字典中查找,该对象将会被复制到 GraphModule 模块层次结构中的适当位置。

  • graph (Graph) – graph 包含此 GraphModule 用于代码生成的节点

  • class_name (str) – name 表示此 GraphModule 的调试名称。如果未设置,所有错误消息将默认显示为来自 GraphModule。将其设置为 root 的原始名称或在转换上下文中具有意义的名称可能会有所帮助。

注意

此 API 的向后兼容性得到了保证。

add_submodule(target, m)[源代码]

将给定的子模块添加到 self

如果它们是 target 的子路径且尚不存在,这会安装空的模块。

参数
  • target (str) – 新子模块的完整qualified字符串名称(参见nn.Module.get_submodule中的示例,了解如何指定完整qualified字符串。)

  • m (Module) – 当前 Module 中要安装的子模块对象

返回值
子模块是否可以被插入。

为了使该方法返回 True,target 标注的链中每个对象必须要么尚未存在,要么引用一个 nn.Module(而不是参数或其它属性)。

返回类型

bool

注意

此 API 的向后兼容性得到了保证。

属性 代码: str

返回由该 GraphModule 底层的 Graph 生成的 Python 代码。

delete_all_unused_submodules()[源代码]

self 中删除所有未使用的子模块。

如果满足以下任一条件,则认为一个模块被“使用”了:1. 它有被使用的子模块 2. 其 forward 方法通过一个 call_module 节点直接调用 3. 它有一个非模块属性,该属性从 get_attr 节点中被使用

此方法可以被调用以清理一个nn.Module,而无需手动为每个未使用的子模块调用delete_submodule

注意

此 API 的向后兼容性得到了保证。

delete_submodule(target)[源代码]

self中删除指定的子模块。

如果 target 不是有效的目标,模块将不会被删除。

参数

target (str) – 新子模块的完整qualified字符串名称(参见nn.Module.get_submodule中的示例,了解如何指定完整qualified字符串。)

返回值
目标字符串是否引用了某项内容

这个返回值表示我们要删除的是否是一个有效的子模块引用。如果返回值是False,则说明target不是一个有效的子模块引用。

返回类型

bool

注意

此 API 的向后兼容性得到了保证。

property graph: Graph

返回此 GraphModule 的底层 Graph

print_readable(print_output=True, include_stride=False, include_device=False, colored=False)[源代码]

返回当前 GraphModule 及其子模块生成的 Python 代码

警告

此 API 是实验性的,且支持向下兼容。

recompile()[源代码]

根据其graph属性重新编译此 GraphModule。在编辑了包含的graph之后,应调用此方法,否则此GraphModule生成的代码将变得过时。

注意

此 API 的向后兼容性得到了保证。

返回类型

Python代码

to_folder(folder, module_name='FxModule')[源代码]
将模块输出到指定的 folder 文件夹中,并使用 module_name 进行命名,以便于后续操作。

使用from <folder> import <module_name>进行导入

参数:

folder (Union[str, os.PathLike]): 写入代码的目标文件夹

module_name (str): 指定的顶级名称,用于标识 Module

写代码

警告

此 API 是实验性的,且支持向下兼容。

torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]

Graph 是 FX 中间表示中主要的数据结构。它由一系列的 Node 组成,每个节点代表调用点(或其他语法构造)。这些 Node 一起构成一个有效的 Python 函数。

例如,以下示例代码

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成如下图形:

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有关Graph中操作的语义,请参见Node

注意

此 API 的向后兼容性得到了保证。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]

创建一个空的图。

注意

此 API 的向后兼容性得到了保证。

call_function(the_function, args=None, kwargs=None, type_expr=None)[源代码]

Graph中插入一个call_function Node。这个call_function节点表示对由the_function指定的Python可调用对象的调用。

参数
  • the_function (Callable[..., Any]) – 需要调用的函数。它可以是任何 PyTorch 操作符、Python 函数或 builtinsoperator 命名空间中的成员。

  • args (Optional[Tuple[Argument, ...]]) – 被调用函数所需的位置参数。

  • kwargs (Optional[Dict[str, Argument]]) – 传递给调用函数的关键词参数

  • type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。

返回值

新创建并插入的call_function节点。

返回类型

节点

注意

此方法的插入点和类型表达式的规则与 Graph.create_node() 方法的规则相同。

注意

此 API 的向后兼容性得到了保证。

call_method(method_name, args=None, kwargs=None, type_expr=None)[源代码]

Graph中插入一个call_method Node。这个call_method节点表示对args中的第0个元素调用指定的方法。

参数
  • method_name (str) – 指定要应用于 self 参数的方法名。例如,如果 args[0] 是一个表示 TensorNode,则要对该 Tensor 调用 relu() 方法,应将 relu 传递给 method_name

  • args (Optional[Tuple[Argument, ...]]) – 被调用方法所需的位置参数。请注意,这 应该 包括一个 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 传递给被调用方法的关键词参数

  • type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。

返回值

新创建并插入的call_method节点。

返回类型

节点

注意

此方法的插入点和类型表达式的规则与 Graph.create_node() 方法的规则相同。

注意

此 API 的向后兼容性得到了保证。

call_module(module_name, args=None, kwargs=None, type_expr=None)[源代码]

Graph中插入一个call_module Node。这个call_module节点表示对Module层次结构中的某个Moduleforward()函数的调用。

参数
  • module_name (str) – 要调用的 Module 的全限定名称。例如,如果被跟踪的 Module 包含一个名为 foo 的子模块,并且该子模块包含一个名为 bar 的子模块,则应将全限定名称 foo.bar 作为参数传递给 module_name,以调用该模块。

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的位置参数。请注意,不应包含 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 传递给被调用方法的关键词参数

  • type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。

返回值

新创建并插入的call_module节点。

返回类型

节点

注意

此方法的插入点和类型表达式的规则与 Graph.create_node() 方法的规则相同。

注意

此 API 的向后兼容性得到了保证。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[源代码]

创建一个Node并将其添加到当前的Graph插入点。请注意,可以通过Graph.inserting_before()Graph.inserting_after()来设置当前的插入点。

参数
  • op (str) – 该节点的指令码。可以是 ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’ 或 ‘output’ 中的一个。这些指令码的具体含义在 Graph 的文档字符串中进行了描述。

  • args (Optional[Tuple[Argument, ...]]) – 此节点的参数以元组形式表示。

  • kwargs (Optional[Dict[str, Argument]]) – 此节点的 kwargs 参数

  • name (可选[str]) – 可选的字符串名称,用于标识Node。这将影响在生成的 Python 代码中该值的变量名。

  • type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。

返回值

新创建并插入的节点。

返回类型

节点

注意

此 API 的向后兼容性得到了保证。

eliminate_dead_code(is_impure_node=None)[源代码]

根据每个节点的用户数量和是否具有任何副作用,删除图中的所有死代码。在调用之前,必须先对图进行拓扑排序。

参数
  • is_impure_node (Optional[Callable[ [Node] , bool]] ) – 返回一个函数

  • None ((节点是否为不纯。如果是)) –

  • 即默认行为)–

  • Node.is_impure使用)–

返回值

该pass是否改变了图。

返回类型

bool

示例:

在删除死代码之前,a = x + 1 中的 a 没有被使用,因此可以从图中移除而不产生任何影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

消除死代码后,a = x + 1 被移除,而forward的其他部分保持不变。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除具有一些启发式方法以避免移除有副作用的节点(参见 Node.is_impure),但总体覆盖率很低。因此,除非你知道你的FX图完全由函数操作组成,或者你提供了自己的自定义函数来检测有副作用的节点,否则你应该假设此方法不可靠。

注意

此 API 的向后兼容性得到了保证。

erase_node(to_erase)[源代码]

Graph中移除一个Node。如果Graph中仍然存在该节点的使用者,则会抛出异常。

参数

to_erase (Node) - 需要从 Graph 中删除的节点。

注意

此 API 的向后兼容性得到了保证。

find_nodes(*, op, target=None, sort=True)[源代码]

支持快速查询节点

参数
  • op (str) – 操作的名字

  • target (Optional[Target]) - 节点的目标。在 call_function 中,目标是必需的;而对于其他操作,目标则是可选的。

  • sort (bool) - 是否按照节点在图中出现的顺序来返回节点。

返回值

具有指定操作和目标的节点集合的可迭代对象。

警告

此 API 是实验性的,且支持向下兼容。

get_attr(qualified_name, type_expr=None)[源代码]

在图中插入一个get_attr节点。这个get_attr节点表示从Module层次结构中获取属性。

参数
  • qualified_name (str) – 要检索的属性的完全限定名称。例如,如果被跟踪模块包含一个名为foo的子模块,该子模块又包含一个名为bar的子模块,并且这个子模块有一个名为baz的属性,则应将完全限定名称foo.bar.baz作为参数传递。

  • type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。

返回值

新创建并插入的get_attr节点。

返回类型

节点

注意

此方法的插入点和类型表达式规则与 Graph.create_node 方法相同。

注意

此 API 的向后兼容性得到了保证。

graph_copy(g, val_map, return_output_node=False)[源代码]

将给定图形中的所有节点复制到 self

参数
  • g (Graph) – 用于复制节点的源图。

  • val_map (Dict[Node, Node]) – 一个字典,用于存储从g中的节点到self中节点的映射。注意,可以传入已经包含某些值的val_map以覆盖这些值的复制。

返回值

如果 g 有一个 output 节点,那么现在与 g 的输出值等价的 self 中的值。否则为 None

返回类型

Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 的向后兼容性得到了保证。

inserting_after(n=None)[源代码]
设置创建节点及其相关方法在图中的插入位置。

当在“with”语句中使用时,这会暂时设置插入点,并在“with”语句结束时恢复:

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently

参数:

n (Optional[Node]): 要在其之前插入的节点。如果为 None,则在末尾插入。

图形的开始部分。

返回:

一个资源管理器,将在__exit__时恢复插入位置。

注意

此 API 的向后兼容性得到了保证。

inserting_before(n=None)[源代码]
设置创建节点及其相关方法在图中的插入位置。

当在“with”语句中使用时,这会暂时设置插入点,并在“with”语句结束时恢复:

with g.inserting_before(n):
    ... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) #  set the insert point permanently

参数:

n (Optional[Node]): 要在其前插入的节点。若为 None,则在开头插入。

图形的开始部分。

返回:

一个资源管理器,将在__exit__时恢复插入位置。

注意

此 API 的向后兼容性得到了保证。

lint()[源代码]

对此图进行各种检查,确保其结构正确。具体包括:
- 确认节点的拥有权是否正确(由该图所有)
- 验证节点是否按拓扑顺序排列
- 如果此图属于一个GraphModule,则检查目标是否存在于此GraphModule中

注意

此 API 的向后兼容性得到了保证。

node_copy(node, arg_transform=<function Graph.<lambda>>)[源代码]

将一个节点从一个图复制到另一个图中。arg_transform 需要将源节点图中的参数转换为目标图中的相应参数。示例:

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
参数
  • node (Node) - 需要复制到 self 中的节点。

  • arg_transform (Callable[[Node], Argument]) – 一个函数,用于将 Node 参数从节点的 argskwargs 转换为在 self 中等效的参数。 在最简单的情况下,这个函数会从映射原始图中的 Nodeself 的表中检索值。

返回类型

节点

注意

此 API 的向后兼容性得到了保证。

属性节点: _node_list

获取组成此图的节点列表。

请注意,这个Node列表表示的是一个双向链表。在迭代过程中进行的修改操作(如删除或添加一个节点)是安全的。

返回值

一个由节点组成的双向链表。注意,可以通过调用reversed来切换迭代顺序。

on_generate_code(make_transformer)[源代码]

在生成 Python 代码时注册一个变换函数

参数:
make_transformer ( Callable[[ Optional[TransformCodeFunc] ], TransformCodeFunc] ):

一个用于返回并注册代码转换器的函数。该函数被on_generate_code调用以获取代码转换器。

此函数还将当前注册的代码转换器(如果没有注册则为 None)作为输入参数传递,以防止意外覆盖。这样可以方便地将多个代码转换器串联起来使用。

返回:

一个上下文管理器,在使用 with 语句时,会自动恢复之前注册的代码转换器。

示例:

gm: fx.GraphModule = ...

# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]

# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(
    lambda _: insert_pdb
)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans
            else body
        )
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

此函数还可以用作上下文管理器,能够自动恢复之前注册的代码转换器。

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此 API 是实验性的,且支持向下兼容。

output(result, type_expr=None)[源代码]

Graph中插入一个output Node。一个output节点表示Python代码中的return语句。result是要返回的结果。

参数
  • result (参数) - 需要返回的值。

  • type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。

注意

此方法的插入点和类型表达式规则与 Graph.create_node 方法相同。

注意

此 API 的向后兼容性得到了保证。

placeholder(name, type_expr=None, default_value)[源代码]

在图中插入一个占位符节点,它表示函数的输入。

参数
  • name (str) – 输入值的名称。这对应于该 Graph 表示的函数的位置参数。

  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出的 Python 类型。在某些情况下(例如,在 TorchScript 编译过程中使用该函数时),这对于正确的代码生成是必需的。

  • default_value (Any) – 该函数参数的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 传递给此参数以指定参数没有默认值。

返回类型

节点

注意

此方法的插入点和类型表达式规则与 Graph.create_node 方法相同。

注意

此 API 的向后兼容性得到了保证。

print_tabular()[源代码]

以表格形式打印图的中间表示。请注意,使用此功能需要先安装 tabulate 模块。

注意

此 API 的向后兼容性得到了保证。

process_inputs(*args)[源代码]

处理参数,以便可以将其传递给FX图。

警告

此 API 是实验性的,且支持向下兼容。

process_outputs(out)[源代码]

警告

此 API 是实验性的,且支持向下兼容。

python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[源代码]

将这个 Graph 转换为有效的 Python 代码。

参数

root_module (str) – 根模块的名称,用于查找qualified name目标。这通常是‘self’。

返回值

src: 表示对象的 Python 源代码;globals: 包含 src 中全局名称及其引用对象的字典。

返回类型

一个PythonCode对象,包含两个字段。

注意

此 API 的向后兼容性得到了保证。

set_codegen(codegen)[源代码]

警告

此 API 是实验性的,且支持向下兼容。

classtorch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[源代码]

Node 是表示 Graph 中各个操作的数据结构。通常情况下,Node 表示对各种实体(如运算符、方法和模块)的调用位置。一些例外情况包括指定函数输入和输出的节点。每个 Node 都有一个由其 op 属性定义的函数。对于不同的 op 值,Node 的语义如下:

  • placeholder 用于表示函数输入,name 属性指定该值的名称。target 同样是参数的名称。而args 可以包含:1)无内容,或 2)单个默认参数。 kwargs 不重要。占位符对应于图打印中的函数参数(例如 x)。

  • get_attr 从模块层次结构中获取一个参数。其中,name 是获取结果的变量名,target 是该参数在模块层次结构中的完整限定名称。argskwargs 则无关紧要。

  • call_function 将一个自由函数应用于某些值。其中 name 是要分配给该值的名称,target 是要应用的函数。argskwargs 表示传递给函数的参数,遵循 Python 的调用约定。

  • call_module 使用给定的参数调用模块层次结构中的 forward() 方法。name 与之前相同。 target 是要调用的模块在模块层次结构中的全限定名称,而 argskwargs 表示用于调用该模块的参数(不包括 self 参数)。

  • call_method 在一个值上调用方法。同样地,name 也类似。target 是要应用于 self 参数的方法的字符串名称。argskwargs 表示调用该方法时的所有参数,包括 self 参数

  • output 包含了被追踪函数的返回值,存储在其 args[0] 属性中。这与 Graph 输出中的“return”语句相对应。

注意

此 API 的向后兼容性得到了保证。

属性all_input_nodes:List[Node]

返回所有输入到该节点的节点。这相当于遍历 argskwargs,并只收集其中为节点的值。

返回值

按顺序出现在此 Nodeargskwargs 中的 Nodes 列表。

append(x)[源代码]

在此节点的节点列表中插入 x。相当于执行 self.next.prepend(x)

参数

x (Node) – 需要放置在此节点后的节点。此节点和目标节点必须属于同一张图。

注意

此 API 的向后兼容性得到了保证。

propertyargs:Tuple[Optional[Union[Tuple[Any,...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]],...]

这是传递给Node的参数元组。参数的具体含义由节点的操作码决定。更多详细信息,请参阅Node的文档字符串。

允许对此属性进行赋值。在赋值时,所有使用情况和用户的相关记录会自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)[源代码]

返回self的描述性字符串表示。

此方法可以在不带任何参数的情况下用作调试工具。

此函数还用于 Graph 类的内部 __str__ 方法中。在该图周围 GraphModule 中自动生成的 forward 函数签名由 placeholder_namesmaybe_return_typename 中的字符串组成。placeholder_namesmaybe_return_typename 不应以其他方式使用。

参数
  • placeholder_names (Optional[List[str]]) – 一个存储生成的 forward 函数中占位符格式化字符串的列表。仅供内部使用。

  • maybe_return_typename (Optional[List[str]]) – 一个单元素列表,用于存储表示生成的 forward 函数输出的格式化字符串。仅供内部使用。

返回值
如果 1) 我们使用 format_node 作为内部辅助函数

Graph__str__方法中,如果self是一个占位符节点,则返回None。否则,返回当前节点的描述性字符串表示。

返回类型

str

注意

此 API 的向后兼容性得到了保证。

insert_arg(idx, arg)[源代码]

在指定索引的位置向参数列表中插入一个参数。

参数
  • idx (int) – 在 self.args 中要插入的元素之前的索引位置。

  • arg (参数) – 需要插入到 args 中的新参数值

注意

此 API 的向后兼容性得到了保证。

is_impure()[源代码]

返回该操作是否为不纯操作,即该操作是一个占位符或输出,或者是不纯的 call_function 或 call_module。

返回值

如果是不纯的操作或不是纯函数。

返回类型

bool

警告

此 API 是实验性的,且支持向下兼容。

propertykwargs:Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload,SymInt, SymBool, SymFloat]]

这是传递给Node的关键词参数字典。参数的具体含义取决于节点的操作码。更多详细信息,请参阅Node的文档字符串。

允许对此属性进行赋值。在赋值时,所有使用情况和用户的相关记录会自动更新。

属性next:Node

返回链表中下一个 Node

返回值

在节点链表中的下一个Node

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[源代码]

返回标准化后的参数给 Python 目标。这意味着 args/kwargs 将与模块或函数的签名匹配,并在 normalize_to_only_use_kwargs 为 true 时按位置顺序返回仅 kwargs 参数。还会填充默认值。不支持仅限位置的参数和可变参数。

支持模块调用。

可能需要arg_typeskwarg_types来区分函数的重载。

参数
  • root (torch.nn.Module) – 用于解析模块目标的基础模块。

  • arg_types (Optional[Tuple[Any]]) – 参数类型的元组

  • kwarg_types (Optional[Dict[str, Any]]) – kwargs 的类型参数字典

  • normalize_to_only_use_kwargs (bool) – 是否将规范改为只使用关键字参数。

返回值

返回 NamedTuple ArgsKwargsPair,如果失败则返回 None

返回类型

Optional[ArgsKwargsPair]

警告

此 API 是实验性的,且支持向下兼容。

prepend(x)[源代码]

在此节点前面的图的节点列表中插入 x。 示例:

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
参数

x (Node) – 需要放置在当前节点之前的一个节点,且必须属于同一张图。

注意

此 API 的向后兼容性得到了保证。

属性prev: Node

返回链表中之前的 Node 节点。

返回值

前一个 Node,在节点的链表中。

replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[源代码]

将图中所有的 self 替换为节点 replace_with

参数
  • replace_with (Node) - 用于替换所有 self 使用的节点。

  • delete_user_cb (Callable) – 一个回调函数,用于决定是否应该删除自节点中的特定用户。

  • propagate_meta (bool) – 是否将原始节点的.meta字段上的所有属性复制到替换节点上。出于安全考虑,此操作仅在替换节点尚未具有现有的.meta字段时才有效。

返回值

在此更改中涉及的节点列表。

返回类型

Node

注意

此 API 的向后兼容性得到了保证。

replace_input_with(old_input, new_input)[源代码]

遍历 self 的输入节点,将其中所有的 old_input 替换为 new_input

参数
  • old_input (Node) – 需要被替换的旧输入节点。

  • new_input (Node) – 用于替换 old_input 的新输入节点。

注意

此 API 的向后兼容性得到了保证。

属性stack_trace:Optional[str]

返回在跟踪过程中记录的 Python 堆栈跟踪(如果有)。通常,当使用 fx.Tracer 进行跟踪时,此属性会由 Tracer.create_proxy 方法填充。为了在调试目的下记录堆栈跟踪,请在 Tracer 实例上设置 record_stack_traces = True。默认情况下,当使用 dynamo 进行跟踪时,此属性将由 OutputGraph.create_proxy 填充。

stack_trace会在字符串的末尾包含最内层的帧信息。

update_arg(idx, arg)[源代码]

将现有位置参数的值更新为 arg。更新后,self.args[idx] == arg

参数
  • idx (int) – 在 self.args 中要更新的元素的索引

  • arg (参数) – 新的参数值,将被写入 args

注意

此 API 的向后兼容性得到了保证。

update_kwarg(key, arg)[源代码]

将现有关键字参数的值更新为 arg。更新后,self.kwargs[key] == arg

参数
  • key (str) – 需要更新的元素在 self.kwargs 中的键

  • arg (参数) – 需要写入 kwargs 的新的参数值

注意

此 API 的向后兼容性得到了保证。

torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[源代码]

Tracer 是实现 torch.fx.symbolic_trace 符号跟踪功能的类。调用 symbolic_trace(m) 等同于调用 Tracer().trace(m)

可以通过子类化 Tracer 来覆盖跟踪过程中的各种行为。可以覆盖的行为在该类方法的文档字符串中有详细说明。

注意

此 API 的向后兼容性得到了保证。

call_module(m, forward, args, kwargs)[源代码]

定义此Tracer在遇到nn.Module实例的调用时应如何处理的方法。

默认情况下,会通过is_leaf_module检查被调用的模块是否为叶模块。如果是,则在Graph中生成一个引用mcall_module节点。否则,正常调用该Module并跟踪其forward函数中的操作。

你可以重写此方法来创建嵌套的追踪 GraphModules,或在跨越 Module 边界时实现其他所需的行为。

参数
  • m (Module) – 要进行调用的模块

  • forward (Callable) – Module 中要调用的 forward() 方法

  • args (元组) – 模块调用位置的参数

  • kwargs (Dict) – 模块调用位置的 kwargs 参数

返回值

模块调用的返回值。如果是call_module节点发出的情况,那么返回的是一个Proxy值。否则,返回的就是从Module调用中得到的实际值。

返回类型

Any

注意

此 API 的向后兼容性得到了保证。

create_arg(a)[源代码]

一种指定在为 中节点参数准备值时进行追踪行为的方法。

默认的行为包括:

  1. 遍历各种集合类型(如元组、列表、字典),并对其内的元素递归调用create_args

  2. 给定一个Proxy对象,返回底层IR Node的引用。

  3. 对于给定的非代理张量对象,为不同情形生成IR代码:

    • 对于一个参数,生成一个指向该参数的get_attr节点

    • 对于非参数张量,将其存储在一个特殊的属性中,并引用该属性。

可以重写此方法以支持更多类型的对象。

参数

a (任意) – 在 Graph 中作为 Argument 发射的值。

返回值

将值 a 转换为适当的 Argument

返回类型

参数

注意

此 API 的向后兼容性得到了保证。

create_args_for_root(root_fn, is_module, concrete_args=None)[源代码]

创建与root模块签名对应的placeholder节点。此方法会分析根模块的签名并相应地生成这些节点,同时也支持*args**kwargs

警告

此 API 是实验性的,且支持向下兼容。

create_node(kind, target, args, kwargs, name=None, type_expr=None)

根据目标、参数 args、关键字参数 kwargs 和名称来插入一个图节点。

可以重写此方法来执行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录就地操作。

注意

此 API 的向后兼容性得到了保证。

返回类型

节点

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)

根据给定的参数创建一个节点,然后将其封装在代理对象中返回。

如果 kind = ‘placeholder’,那么我们创建一个表示函数参数的 Node。如果我们需要为默认参数编码,则使用 args 元组。args 对于 placeholder Nodes 而言则为空。

注意

此 API 的向后兼容性得到了保证。

get_fresh_qualname(prefix)[源代码]

为前缀获取一个新名称并返回,以确保它不会与图中现有的属性冲突。

注意

此 API 的向后兼容性得到了保证。

返回类型

str

getattr(attr, attr_val, parameter_proxy_cache)[源代码]

定义此Tracer在调用nn.Module实例时处理getattr方法的行为。

默认情况下,会为属性返回一个代理值。此外,它还会将该代理值存储在 parameter_proxy_cache 中,这样未来的调用可以重复使用这个代理,而无需重新创建一个新的。

此方法可以被重写,以便在查询参数时不要返回代理。

参数
  • attr (str) – 查询的属性名

  • attr_val (Any) – 表示属性的值

  • parameter_proxy_cache (Dict[str, Any]) – 一个缓存,将属性名映射到代理对象。

返回值

getattr 调用返回的值。

警告

此 API 是实验性的,且支持向下兼容。

is_leaf_module(m, module_qualified_name)[源代码]

一种方法,用于指定给定的 nn.Module 是否为“叶子”模块。

叶模块是出现在IR中的基本单元,通过call_module调用进行引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块都是叶模块。除非另有指定,否则所有其他模块都将被追踪,并记录其内部的操作。

参数
  • m (Module) – 需要查询的模块

  • module_qualified_name (str) – 表示此模块路径的字符串。例如,如果你有一个模块层次结构,其中子模块foo包含子模块bar,而bar又包含子模块baz,那么该模块将在此处以全限定名为foo.bar.baz的形式出现。

返回类型

bool

注意

此 API 的向后兼容性得到了保证。

iter(obj)
在迭代代理对象时被调用,例如:

在控制流中使用时,通常我们不知道该如何操作,因为不知道代理对象的值。但是,一个自定义追踪器可以使用 create_node 向图节点添加更多详情,并且可以选择返回一个迭代器。

注意

此 API 的向后兼容性得到了保证。

返回类型

迭代器

keys(obj)
当代理对象调用 keys() 方法时触发。

这是在代理上调用 ** 时发生的情况。它应该返回一个迭代器,并且 ** 应该在你的自定义追踪器中正常工作。

注意

此 API 的向后兼容性得到了保证。

返回类型

Any

path_of_module(mod)[源代码]

这是一个辅助方法,用于在root的模块层次结构中查找mod的全名。例如,如果root有一个名为foo的子模块,并且该foo子模块还有一个名为bar的子模块,则将bar传递给此函数将会返回字符串“foo.bar”。

参数

mod (str) – 用于检索全限定名称的 Module

返回类型

str

注意

此 API 的向后兼容性得到了保证。

代理(节点)

注意

此 API 的向后兼容性得到了保证。

返回类型

Proxy

to_bool(obj)
当将代理对象转换为布尔值时调用,例如

在控制流中使用时,通常我们不知道该如何操作,因为不知道代理对象的值。但是,一个自定义追踪器可以使用 create_node 向图节点添加更多相关信息,并且可以选择返回一个具体的值。

注意

此 API 的向后兼容性得到了保证。

返回类型

bool

trace(root, concrete_args=None)[源代码]

追踪 root 并返回对应的 FX Graph 表示。这里的 root 可以是一个 nn.Module 实例或一个 Python 函数。

注意,在此调用之后,self.root 可能与传入的 root 不同。例如,当一个自由函数被传递给 trace() 时,我们会创建一个 nn.Module 实例作为根节点,并添加嵌入常量。

参数
  • root (Union[Module, Callable]) – 可以是一个 Module 或者一个需要被追踪的函数。此参数的向后兼容性已得到保证。

  • concrete_args (Optional[Dict[str, any]]) – 具体参数,不应被视为代理。此参数是实验性的,不保证向后兼容性。

返回值

表示传入的 root 语义的 Graph

返回类型

注意

此 API 的向后兼容性得到了保证。

torch.fx.Proxy(node, tracer=None)[源代码]

Proxy 对象是 Node 的包装器,在符号跟踪过程中流经程序,并将它们接触的所有操作(如 torch 函数调用、方法调用和运算符)记录到不断增长的 FX 图中。

如果你在进行图变换,可以为原始的 Node 包装一个自定义的 Proxy 方法,以便使用重载的操作符向 Graph 添加额外的内容。

Proxy 对象不能被迭代。换句话说,如果在循环中或作为*args/**kwargs 函数参数使用 Proxy,符号追踪器会抛出错误。

有两种主要的方法来解决这个问题:1. 将不可追踪的逻辑提取到一个顶级函数中,并使用 fx.wrap 包围它。2. 如果控制流是静态的(即循环迭代次数基于某些超参数),代码可以保持在原来的位置并进行重构,使其类似于:

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

了解更多关于代理内部的详细信息,请参阅torch/fx/README.md中的“Proxy”部分。

注意

此 API 的向后兼容性得到了保证。

torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[源代码]

解释器按节点逐一执行FX图。这种模式很有用,可以用于编写代码转换和进行分析等任务。

可以重写 Interpreter 类中的方法来自定义执行行为。根据调用层次结构,可覆盖的方法如下:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们想将所有 torch.neg 实例替换为 torch.sigmoid,反之亦然(包括它们的 Tensor 方法等价物)。我们可以像这样子类化 Interpreter:

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数
  • module (torch.nn.Module) – 需要执行的模块

  • garbage_collect_values (bool) – 是否在模块执行过程中删除值的最后一次使用后进行清理。这确保了执行过程中的最优内存使用。可以通过禁用此功能来检查所有中间值,例如查看 Interpreter.env 属性。

  • graph (可选[Graph]) – 如果提供了此参数,解释器将使用该图进行执行,而不是module.graph,并利用提供的module来满足状态请求。

注意

此 API 的向后兼容性得到了保证。

boxed_run(args_list)[源代码]

通过解释执行模块并返回结果。这采用“带框”调用约定:传递一个参数列表,该列表会被解释器清空,从而确保输入张量能被及时释放。

注意

此 API 的向后兼容性得到了保证。

call_function(target, args, kwargs)[源代码]

执行 call_function 节点并返回结果。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回类型

Any

返回值

Any: 函数调用返回的值

注意

此 API 的向后兼容性得到了保证。

call_method(target, args, kwargs)[源代码]

执行 call_method 节点并返回结果。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回类型

Any

返回值

Any:方法调用返回的值

注意

此 API 的向后兼容性得到了保证。

call_module(target, args, kwargs)[源代码]

执行一个call_module节点并返回结果。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回类型

Any

返回值

Any: 模块调用返回的值

注意

此 API 的向后兼容性得到了保证。

fetch_args_kwargs_from_env(n)[源代码]

从当前执行环境获取节点 nargskwargs 的具体值。

参数

n (Node) – 获取 argskwargs 的节点。

返回值

带有具体值的 argskwargs,用于 n

返回类型

元组[元组, 字典]

注意

此 API 的向后兼容性得到了保证。

fetch_attr(target)[源代码]

self.moduleModule层次结构中获取一个属性。

参数

target (str) – 需要获取的目标属性的完整名称

返回值

该属性的值。

返回类型

任何

注意

此 API 的向后兼容性得到了保证。

get_attr(target, args, kwargs)[源代码]

执行一个get_attr节点,从中检索self.moduleModule层次结构中的属性值。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回值

检索到的属性的值

返回类型

任何

注意

此 API 的向后兼容性得到了保证。

map_nodes_to_values(args, n)[源代码]

递归遍历 args,并在当前执行环境中为每个 Node 查找具体的值。

参数
  • args (参数) – 用于查找具体值的数据结构

  • n (Node) — args 所属的节点。这仅用于错误报告。

返回类型

Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 的向后兼容性得到了保证。

output(target, args, kwargs)[源代码]

执行一个output节点。实际上是获取该节点引用的值并返回它。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回值

输出节点所引用的返回值

返回类型

任何

注意

此 API 的向后兼容性得到了保证。

placeholder(target, args, kwargs)[源代码]

执行一个placeholder节点。需要注意的是,这是有状态的操作: Interpreter维护了一个内部迭代器,用于遍历传递给run方法的参数,并且此方法返回该迭代器的下一个值。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回值

获取的参数值。

返回类型

任何

注意

此 API 的向后兼容性得到了保证。

run(*args, initial_env=None, enable_io_processing=True)[源代码]

通过解释运行模块并返回结果。

参数
  • *args - 按位置顺序传递给模块的参数

  • initial_env (Optional[Dict[Node, Any]]) – 可选的执行起始环境。这是一个将Node映射到任意值的字典。例如,可以预先填充某些Nodes的结果,以便在解释器中仅进行部分评估。

  • enable_io_processing (bool) – 如果为 true,则在使用输入和输出之前,先用图的 process_inputs 和 process_outputs 函数进行处理。

返回值

执行模块返回的值

返回类型

任何

注意

此 API 的向后兼容性得到了保证。

run_node(n)[源代码]

运行特定节点 n 并返回结果。根据 node.op 的值,调用相应的函数,如 placeholder、get_attr、call_function、call_method、call_module 或 output。

参数

n (Node) — 需要执行的节点

返回值

n执行的结果

返回类型

任何

注意

此 API 的向后兼容性得到了保证。

torch.fx.Transformer(module)[源代码]

Transformer 是一种特殊的解释器,它生成一个新的 Module。它提供了一个 transform() 方法来返回转换后的 Module。与需要参数的 Interpreter 不同,Transformer 运行时不需要任何参数,并且完全以符号方式进行工作。

示例

假设我们想将所有 torch.neg 实例替换为 torch.sigmoid,反之亦然(包括它们的 Tensor 方法等价物)。我们可以像这样子类化 Transformer:

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数

module (GraphModule) - 要进行转换的 Module

注意

此 API 的向后兼容性得到了保证。

call_function(target, args, kwargs)[源代码]

注意

此 API 的向后兼容性得到了保证。

返回类型

Any

call_module(target, args, kwargs)[源代码]

注意

此 API 的向后兼容性得到了保证。

返回类型

Any

get_attr(target, args, kwargs)[源代码]

执行一个get_attr节点。在Transformer中,会重写此操作以在输出图中插入一个新的get_attr节点。

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回类型

Proxy

注意

此 API 的向后兼容性得到了保证。

placeholder(target, args, kwargs)[源代码]

执行一个placeholder节点。在Transformer中,会重新定义此操作以在输出图中插入一个新的placeholder

参数
  • target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node

  • args (元组) – 本次调用的位置参数组成的元组

  • kwargs (Dict) – 此次调用的关键词参数字典

返回类型

Proxy

注意

此 API 的向后兼容性得到了保证。

transform()[源代码]

self.module 转换,并返回转换后的 GraphModule

注意

此 API 的向后兼容性得到了保证。

返回类型

GraphModule

torch.fx.replace_pattern(gm, pattern, replacement)[源代码]

在图模块(gm)的图中,找到所有可能的非重叠操作符及其数据依赖关系集合(pattern),然后将每个匹配到的子图替换为另一个子图(replacement)。

参数
返回值

表示在原始图中pattern匹配位置的Match对象列表。如果不存在匹配项,该列表将为空。Match定义如下:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
返回类型

匹配列表

示例:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将在traced_moduleforward方法中首先匹配pattern。模式匹配是基于使用定义关系进行的,而不是节点名称。例如,如果你在pattern中有p = torch.cat([a, b]),你可以在原始的forward函数中找到与之对应的代码m = torch.cat([a, b]),即使变量名不同(p vs m)。

pattern 中,return 语句仅根据其值进行匹配;它可能与更大的图中的 return 语句匹配,也可能不匹配。换句话说,模式不必延伸到更大图的末尾。

当模式匹配成功时,它将从较大的函数中移除,并被replacement 替换。如果较大函数中有多个 pattern 匹配项,则每个非重叠的匹配项都将被替换。对于重叠的匹配项,将在一组重叠的匹配项中找到的第一个匹配项进行替换。(这里的“第一个”是指根据节点使用定义关系的拓扑排序中的第一个节点。通常情况下,第一个节点是直接出现在 self 之后的参数,而最后一个节点则是函数返回的内容。)

需要注意的是,pattern 可调用函数中的参数必须在该可调用函数中使用,而 replacement 可调用函数的参数必须与模式匹配。这就是为什么在上述代码块中,forward 函数有参数 x, w1, w2,但 pattern 函数只有参数 w1, w2。由于 pattern 不使用 x 参数,因此不应将其指定为参数。作为第二条规则的一个例子,请考虑替换操作。

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

使用

def replacement(x, y):
    return torch.relu(x)

在这种情况下,replacement 需要与 pattern 相同数量的参数(即 xy),即使 y 参数在 replacement 中没有被使用。

调用 subgraph_rewriter.replace_pattern 之后,生成的 Python 代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

此 API 的向后兼容性得到了保证。

本页目录