torch.fx
概述
FX 是一个开发工具包,用于转换 nn.Module
实例。它包含三个主要组件:符号追踪器、中间表示和 Python 代码生成。以下是这些组件的示例演示:
import torch # Simple module for demonstration class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return self.linear(x + self.param).clamp(min=0.0, max=1.0) module = MyModule() from torch.fx import symbolic_trace # Symbolic tracing frontend - captures the semantics of the module symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) # High-level intermediate representation (IR) - Graph representation print(symbolic_traced.graph) """ graph(): %x : [num_users=1] = placeholder[target=x] %param : [num_users=1] = get_attr[target=param] %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {}) %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) return clamp """ # Code generation - valid Python code print(symbolic_traced.code) """ def forward(self, x): param = self.param add = x + param; x = param = None linear = self.linear(add); add = None clamp = linear.clamp(min = 0.0, max = 1.0); linear = None return clamp """
符号跟踪器 对 Python 代码执行“符号执行”。它将假值(称为代理 Proxies)传递给代码。对这些代理的操作会被记录下来。有关符号跟踪的更多信息可以在 symbolic_trace()
和 Tracer
文档中找到。
中间表示(Intermediate Representation,简称IR)是符号跟踪期间记录的操作容器。它由一系列节点组成,这些节点代表函数输入、调用站点(包括函数、方法或torch.nn.Module
实例),以及返回值。关于IR的更多信息可以在 Graph
文档中找到。IR是应用变换的基础格式。
Python代码生成 是使FX成为Python到Python(或模块到模块)转换工具包的关键。对于每个图IR,我们可以创建与其语义匹配的有效Python代码。此功能封装在GraphModule
中,它是持有Graph
以及从该图生成的forward
方法的torch.nn.Module
实例。
这些组件组成的管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)共同构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号跟踪可以在隔离状态下用于捕获代码的一种形式以进行分析(而非转换)。代码生成可用于根据配置文件等程序化地生成模型。FX 具有许多用途!
几个示例转换可以在示例仓库中找到。
编写转换
什么是FX变换?本质上,它是一个这样的函数。
import torch import torch.fx def transform(m: nn.Module, tracer_class : type = torch.fx.Tracer) -> torch.nn.Module: # Step 1: Acquire a Graph representing the code in `m` # NOTE: torch.fx.symbolic_trace is a wrapper around a call to # fx.Tracer.trace and constructing a GraphModule. We'll # split that out in our transform to allow the caller to # customize tracing behavior. graph : torch.fx.Graph = tracer_class().trace(m) # Step 2: Modify this Graph or create a new one graph = ... # Step 3: Construct a Module to return return torch.fx.GraphModule(m, graph)
你的转换会接收一个torch.nn.Module
,从中获取一个Graph
,进行一些修改,并返回一个新的torch.nn.Module
。你应该将你的FX转换返回的torch.nn.Module
视为与常规的torch.nn.Module
相同——你可以将其传递给另一个FX转换,可以传递给TorchScript,或者直接运行它。确保你的FX转换的输入和输出是一个torch.nn.Module
将使其具有组合性。
注意
你也可以选择修改现有的GraphModule
,而不需要创建新的,具体方法如下:
import torch import torch.fx def transform(m : nn.Module) -> nn.Module: gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m) # Modify gm.graph # <...> # Recompile the forward() method of `gm` from its Graph gm.recompile() return gm
请注意,你必须调用GraphModule.recompile()
,以确保生成的forward()
方法与修改后的Graph
保持同步。
鉴于你传入了一个已经被追踪为 torch.nn.Module
的 Graph
,现在有两种主要的方法可以用来构建一个新的 Graph
。
图的基础简介
关于图的语义的完整处理可以在 Graph
文档中找到,但这里我们将介绍基础知识。一个 Graph
是一种数据结构,表示在 GraphModule
上的方法。所需的信息包括:
-
这个方法的输入参数是什么?
-
方法内执行了哪些操作?
-
该方法的返回值是什么?
这三个概念都使用Node
实例来表示。让我们通过一个简短的例子来看看具体含义:
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk(torch.sum( self.linear(x + self.linear.weight).relu(), dim=-1), 3) m = MyModule() gm = torch.fx.symbolic_trace(m) gm.graph.print_tabular()
在这里,我们为了演示目的定义了一个名为MyModule
的模块,对其进行实例化和符号跟踪,然后调用Graph.print_tabular()
方法来打印显示该Graph
节点的表格。
操作码
姓名
目标
参数
kwargs
占位符
x
x
()
{}
获取属性
线性权重
linear权重
()
{}
调用函数
添加_1
add(内置函数)
(x, 线性权重)
{}
调用模块
线性_1
线性
(add_1,)
{}
调用方法
ReLU_1
ReLU
(linear_1)
{}
调用函数
总述_1
<内置求和方法>
relu_1
{"dim": -1}
调用函数
topk_1
<内置方法 topk >
sum(1, 3)
{}
输出结果
输出结果
输出结果
(topk_1,)
{}
-
方法的输入是什么?在FX中,方法输入通过特殊的
placeholder
节点指定。在这种情况下,我们有一个具有target
为x
的单个placeholder
节点,这意味着我们有一个名为x的单一(非self)参数。 -
方法内的操作有哪些?
get_attr
、call_function
、call_module
和call_method
节点代表了方法中的操作。所有这些节点的完整语义解释可以在Node
文档中找到。 -
方法的返回值是什么?在
Graph
中,返回值由一个特殊的output
节点指定。
既然我们现在掌握了代码在FX中的基本表示方法,就可以开始探索如何编辑一个Graph
了。
图形操作
直接图 Manipulation
构建新的Graph
的一种方法是直接操作现有的图。为此,我们可以获取符号跟踪得到的Graph
并进行修改。例如,假设我们希望将torch.add()
调用替换为torch.mul()
调用。
import torch import torch.fx # Sample module class M(torch.nn.Module): def forward(self, x, y): return torch.add(x, y) def transform(m: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: graph : fx.Graph = tracer_class().trace(m) # FX represents its Graph as an ordered list of # nodes, so we can iterate through them. for node in graph.nodes: # Checks if we're calling a function (i.e: # torch.add) if node.op == 'call_function': # The target attribute is the function # that call_function calls. if node.target == torch.add: node.target = torch.mul graph.lint() # Does some checks to make sure the # Graph is well-formed. return fx.GraphModule(m, graph)
我们还可以执行更复杂的 Graph
重写操作,例如删除或添加节点。为了帮助这些转换,FX 提供了一些用于图变换的实用函数,这些函数可以在Graph
文档中找到。下面提供了一个使用这些 API 添加 torch.relu()
调用的例子。
# Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node): # Insert a new `call_function` node calling `torch.relu` new_node = traced.graph.call_function( torch.relu, args=(node,)) # We want all places that used the value of `node` to # now use that value after the `relu` call we've added. # We use the `replace_all_uses_with` API to do this. node.replace_all_uses_with(new_node)
对于仅包含替换的简单转换,你还可以使用子图重写器。
利用replace_pattern()进行子图重写
FX 还在直接图操作的基础上提供了另一层自动化。replace_pattern()
API 实质上是一个用于编辑Graph
的“查找/替换”工具。它允许你指定一个pattern
和replacement
函数,并遍历这些函数,找到在pattern
图中的操作组实例,并用replacement
图的副本替换它们。这可以帮助大大简化繁琐的图操作代码,在转换变得更加复杂时,这种代码可能会变得难以管理。
代理/回溯
另一种操作 Graph
的方式是重用符号跟踪中使用的 Proxy
机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小的操作的转换器。它会将每个 F.relu(x)
调用转换为 (x > 0) * x
。一种可能是执行必要的图重写,在 F.relu
后插入比较和乘法操作,然后清理原始的 F.relu
。然而,我们可以通过使用 Proxy
对象自动记录操作到 Graph
中来自动化这个过程。
要使用此方法,我们编写常规的 PyTorch 代码来插入所需的操作,并用 Proxy
对象作为参数调用该代码。这些 Proxy
对象会捕获在其上执行的操作并将它们追加到 Graph
中。
# Note that this decomposition rule can be read as regular Python def relu_decomposition(x): return (x > 0) * x decomposition_rules = {} decomposition_rules[F.relu] = relu_decomposition def decompose(model: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: """ Decompose `model` into smaller constituent operations. Currently,this only supports decomposing ReLU into its mathematical definition: (x > 0) * x """ graph : fx.Graph = tracer_class().trace(model) new_graph = fx.Graph() env = {} tracer = torch.fx.proxy.GraphAppendingTracer(new_graph) for node in graph.nodes: if node.op == 'call_function' and node.target in decomposition_rules: # By wrapping the arguments with proxies, # we can dispatch to the appropriate # decomposition rule and implicitly add it # to the Graph by symbolically tracing it. proxy_args = [ fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args] output_proxy = decomposition_rules[node.target](*proxy_args) # Operations on `Proxy` always yield new `Proxy`s, and the # return value of our decomposition rule is no exception. # We need to extract the underlying `Node` from the `Proxy` # to use it in subsequent iterations of this transform. new_node = output_proxy.node env[node.name] = new_node else: # Default case: we don't have a decomposition rule for this # node, so just copy the node over into the new graph. new_node = new_graph.node_copy(node, lambda x: env[x.name]) env[node.name] = new_node return fx.GraphModule(model, new_graph)
除了避免显式图操作外,使用 Proxy
还可以让你将重写规则指定为原生 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。需要注意的是,在调用 Proxy
时,我们还传递了一个指向底层变量图的 tracer。这样做是为了在图中的操作是 n 元(例如 add 是一个二元运算符)的情况下,Proxy
调用不会创建多个图追踪器实例,这可能会导致意外的运行时错误。我们特别推荐在这种情况下使用 Proxy
,尤其是在不能安全地假设底层操作符为一元的情况下。
一个使用Proxy
进行Graph
操作的示例可以在这里找到。
解释器模式
FX 中一个有用的代码组织模式是遍历 Graph
中的所有 Node
并执行它们。这种模式可以用于多种用途,例如通过运行时分析流经图的值或使用 Proxy
重新跟踪来转换代码。例如,假设我们希望运行一个 GraphModule
并在运行时记录我们在节点上看到的 torch.Tensor
的形状和数据类型属性。这可能看起来像:
import torch import torch.fx from torch.fx.node import Node from typing import Dict class ShapeProp: """ Shape propagation. This class takes a `GraphModule`. Then, its `propagate` method executes the `GraphModule` node-by-node with the given arguments. As each operation executes, the ShapeProp class stores away the shape and element type for the output values of each operation on the `shape` and `dtype` attributes of the operation's `Node`. """ def __init__(self, mod): self.mod = mod self.graph = mod.graph self.modules = dict(self.mod.named_modules()) def propagate(self, *args): args_iter = iter(args) env : Dict[str, Node] = {} def load_arg(a): return torch.fx.graph.map_arg(a, lambda n: env[n.name]) def fetch_attr(target : str): target_atoms = target.split('.') attr_itr = self.mod for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") attr_itr = getattr(attr_itr, atom) return attr_itr for node in self.graph.nodes: if node.op == 'placeholder': result = next(args_iter) elif node.op == 'get_attr': result = fetch_attr(node.target) elif node.op == 'call_function': result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) elif node.op == 'call_method': self_obj, *args = load_arg(node.args) kwargs = load_arg(node.kwargs) result = getattr(self_obj, node.target)(*args, **kwargs) elif node.op == 'call_module': result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) # This is the only code specific to shape propagation. # you can delete this `if` branch and this becomes # a generic GraphModule interpreter. if isinstance(result, torch.Tensor): node.shape = result.shape node.dtype = result.dtype env[node.name] = result return load_arg(self.graph.result)
如你所见,一个完整的FX解释器并不复杂,但非常有用。为了方便使用这一模式,我们提供了Interpreter
类,该类通过方法重写封装了上述逻辑,并允许覆盖解释器执行的某些方面。
除了执行操作之外,我们还可以通过将 Proxy
值传递给解释器来生成新的 Graph。同样地,我们提供了 Transformer
类来涵盖这种模式。Transformer
的行为类似于 Interpreter
,但你不会调用 run
方法来获取 Module 的具体输出值,而是会调用Transformer.transform()
方法以返回一个新的 GraphModule
,该模块经过了你安装的任何转换规则(作为重写的方法)。
调试
介绍
在编写转换的过程中,我们的代码往往不会完全正确。在这种情况下,我们需要进行一些调试工作。关键是倒着来:首先,检查调用生成模块的结果以验证其正确性或错误;然后,检查和调试生成的代码;最后,调试导致生成代码的整个转换过程。
如果您不熟悉调试器,请参见辅助部分中的“可用的调试器”。
验证模块的准确性
由于大多数深度学习模块的输出是由浮点数 torch.Tensor
实例组成的,因此判断两个 torch.nn.Module
的结果是否相等并不像简单的相等性检查那样直接。为了说明这一点,我们来看一个例子:
import torch import torch.fx import torchvision.models as models def transform(m : torch.nn.Module) -> torch.nn.Module: gm = torch.fx.symbolic_trace(m) # Imagine we're doing some transforms here # <...> gm.recompile() return gm resnet18 = models.resnet18() transformed_resnet18 = transform(resnet18) input_image = torch.randn(5, 3, 224, 224) assert resnet18(input_image) == transformed_resnet18(input_image) """ RuntimeError: Boolean value of Tensor with more than one value is ambiguous """
在这里,我们尝试使用==
运算符检查两个深度学习模型的值是否相等。然而,由于该运算符返回的是张量而不是布尔值,并且浮点数比较应考虑误差范围(或epsilon)以处理浮点运算的非交换性问题(详情请参阅此处),因此这种做法是不明确的。我们可以改用torch.allclose()
,它会根据相对和绝对容差阈值进行近似比较:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中的第一个工具,用于检查转换后的模块是否按预期运行,并与参考实现进行比较。
调试生成的代码
由于FX在GraphModule
上生成了forward()
函数,使用传统的调试方法(如print
语句或pdb
)并不那么直接。幸运的是,我们有一些技术可以用来调试生成的代码。
使用 pdb
调用pdb
来进入正在运行的程序。尽管表示Graph
的代码不在任何源文件中,我们仍然可以在前向传递被调用时手动使用pdb
进行调试。
import torch import torch.fx import torchvision.models as models def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: graph = tracer_class().trace(inp) # Transformation logic here # <...> # Return new Module return fx.GraphModule(inp, graph) my_module = models.resnet18() my_module_transformed = my_pass(my_module) input_value = torch.randn(5, 3, 224, 224) # When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line import pdb; pdb.set_trace() my_module_transformed(input_value)
打印生成的代码
如果你想多次运行相同的代码,使用 pdb
跳到正确的代码可能会有点繁琐。在这种情况下,你可以将生成的 forward
传递复制粘贴到你的代码中,并从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some # number of transforms # Copy this code for later print(traced) # Print the code generated from symbolic tracing. This outputs: """ def forward(self, y): x = self.x add_1 = x + y; x = y = None return add_1 """ # Subclass the original Module class SubclassM(M): def __init__(self): super().__init__() # Paste the generated `forward` function (the one we printed and # copied above) here def forward(self, y): x = self.x add_1 = x + y; x = y = None return add_1 # Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version. pre_trace = M() post_trace = SubclassM()
使用GraphModule
中的to_folder
函数
GraphModule.to_folder()
是 GraphModule
中的一个方法,允许你将生成的 FX 代码导出到一个文件夹中。虽然通常复制前向传递代码就足够了(如在打印生成的代码中所示),但使用 to_folder
方法来检查模块和参数可能更方便。
m = symbolic_trace(M()) m.to_folder("foo", "Bar") from foo import Bar y = Bar()
运行上述示例后,我们可以查看并修改 foo/module.py
中的代码(例如添加 print
语句或使用 pdb
),以便调试生成的代码。
调试变换
现在我们已经确定某个转换生成了错误的代码,是时候开始调试这个转换本身了。首先,我们将查看文档中的符号跟踪的限制部分。确认跟踪功能正常后,我们需要找出在GraphModule
转换过程中出现了什么问题。可能在编写转换中有快速解答,但如果找不到的话,有几种方法可以检查我们的追踪模块:
# Sample Module class M(torch.nn.Module): def forward(self, x, y): return x + y # Create an instance of `M` m = M() # Symbolically trace an instance of `M` (returns a GraphModule). In # this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity. traced = symbolic_trace(m) # Print the code produced by tracing the module. print(traced) # The generated `forward` function is: """ def forward(self, x, y): add = x + y; x = y = None return add """ # Print the internal Graph. print(traced.graph) # This print-out returns: """ graph(): %x : [num_users=1] = placeholder[target=x] %y : [num_users=1] = placeholder[target=y] %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {}) return add """ # Print a tabular representation of the internal Graph. traced.graph.print_tabular() # This gives us: """ opcode name target args kwargs ------------- ------ ----------------------- ------ -------- placeholder x x () {} placeholder y y () {} call_function add <built-in function add> (x, y) {} output output output (add,) {} """
使用上述工具函数,我们可以对比应用变换前后的跟踪模块。有时候,简单的视觉比较就能帮助我们找到错误。如果还是无法确定问题所在,可以考虑使用像 pdb
这样的调试器作为下一步的解决方案。
参考上面的例子,考虑以下代码:
# Sample user-defined function def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: # Get the Graph from our traced Module g = tracer_class().trace(module) """ Transformations on `g` go here """ return fx.GraphModule(module, g) # Transform the Graph transformed = transform_graph(traced) # Print the new code after our transforms. Check to see if it was # what we expected print(transformed)
根据上述示例,假设调用print(traced)
显示我们的转换中存在错误。我们希望通过调试器找出问题所在。启动一个pdb
会话后,在transform_graph(traced)
上设置断点,并按s
键“步入”该调用,以查看转换过程中发生了什么。
我们也可以通过修改 print_tabular
方法来打印图中节点的不同属性,例如节点的 input_nodes
和 users
属性。
可用的调试工具
最常用的 Python 调试器是 pdb。你可以在命令行中输入python -m pdb FILENAME.py
来以“调试模式”启动你的程序,其中FILENAME
是你想要调试的文件名。之后,你可以使用 pdb
的调试命令逐步运行你的程序。通常在开始时设置一个断点(b LINE-NUMBER
),然后调用c
让程序运行到该断点处,这样可以避免你需要逐行执行代码(使用 s
或 n
)来到达你想要检查的代码部分。或者,你可以在你希望中断的地方之前写入import pdb; pdb.set_trace()
。如果你添加了pdb.set_trace()
,当你运行程序时它会自动以调试模式启动。(换句话说,你可以直接在命令行中输入 python FILENAME.py
而不是 python -m pdb FILENAME.py
)。一旦你在调试模式下运行你的文件,你可以使用某些命令逐步执行代码并检查程序的内部状态。有许多关于pdb
的优秀教程在线上,包括RealPython的“使用Pdb进行Python调试”。
像 PyCharm 或 VSCode 这样的集成开发环境通常都内置了调试器。在你的 IDE 中,你可以选择两种方式来使用 pdb
:a) 在 IDE 中打开一个终端窗口(例如,在 VSCode 中为 View → Terminal),或者 b) 使用 IDE 内置的图形化调试工具(该工具通常基于 pdb
)。
符号跟踪的局限性
FX 使用一种称为 符号跟踪 (又称 符号执行)的系统,以可转换和分析的形式捕获程序语义。该系统通过执行程序(实际上是torch.nn.Module
或函数)并记录操作来进行跟踪。在执行过程中,流经程序的数据不是实际数据,而是符号(在FX中称为 Proxy
),因此它是符号化的。
虽然符号跟踪适用于大多数神经网络代码,但它也存在一些局限性。
动态控制流
符号跟踪的主要限制在于它当前不支持动态控制流,例如循环或if
语句,这些语句中的条件可能会根据程序的输入值发生变化。
例如,我们来看看以下程序:
def func_to_trace(x): if x.sum() > 0: return torch.relu(x) else: return torch.neg(x) traced = torch.fx.symbolic_trace(func_to_trace) """ <...> File "dyn.py", line 6, in func_to_trace if x.sum() > 0: File "pytorch/torch/fx/proxy.py", line 155, in __bool__ return self.tracer.to_bool(self) File "pytorch/torch/fx/proxy.py", line 85, in to_bool raise TraceError('symbolically traced variables cannot be used as inputs to control flow') torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow """
如果语句的条件依赖于 x.sum()
的值,而 x.sum()
又依赖于输入参数 x
的值。由于 x
可以变化(例如传递一个新的输入张量给被跟踪的函数),这种情况被称为动态控制流。追踪回溯会向上遍历你的代码,展示出这种情形发生的位置。
静态控制流
另一方面,所谓的静态控制流是受支持的。静态控制流指的是那些在不同调用中值不会改变的循环或if
语句。通常,在 PyTorch 程序中,这种类型的控制流出现在根据超参数对模型架构做出决策的代码里。例如:
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self, do_activation : bool = False): super().__init__() self.do_activation = do_activation self.linear = torch.nn.Linear(512, 512) def forward(self, x): x = self.linear(x) # This if-statement is so-called static control flow. # Its condition does not depend on any input values if self.do_activation: x = torch.relu(x) return x without_activation = MyModule(do_activation=False) with_activation = MyModule(do_activation=True) traced_without_activation = torch.fx.symbolic_trace(without_activation) print(traced_without_activation.code) """ def forward(self, x): linear_1 = self.linear(x); x = None return linear_1 """ traced_with_activation = torch.fx.symbolic_trace(with_activation) print(traced_with_activation.code) """ import torch def forward(self, x): linear_1 = self.linear(x); x = None relu_1 = torch.relu(linear_1); linear_1 = None return relu_1 """
if 语句 if self.do_activation
不依赖于任何函数输入,因此它是静态的。do_activation
可以被视为一个超参数,并且不同实例的代码会根据该参数的不同值而有所不同(例如:MyModule
的不同实例)。这是一个由符号跟踪支持的有效模式。
许多动态控制流的实例实际上具有静态控制流的语义。通过移除对输入值的数据依赖(例如,将值移动到Module
属性或在符号跟踪期间将具体值绑定到参数),可以使这些实例支持符号跟踪:
def f(x, flag): if flag: return x else: return x*2 fx.symbolic_trace(f) # Fails! fx.symbolic_trace(f, concrete_args={'flag': True})
对于真正动态的控制流,包含此代码的程序部分可以被追踪为对方法(参见使用Tracer类自定义跟踪)或函数(参见wrap()
)的调用,而不是直接通过它们进行追踪。
非torch
函数
FX 使用 __torch_function__
作为拦截调用的机制(参见技术概述以获取更多信息)。一些函数,如内置 Python 函数或 math
模块中的函数,并未被 __torch_function__
支持,但我们仍然希望在符号跟踪中捕获它们。例如:
import torch import torch.fx from math import sqrt def normalize(x): """ Normalize `x` by the size of the batch dimension """ return x / sqrt(len(x)) # It's valid Python code normalize(torch.rand(3, 4)) traced = torch.fx.symbolic_trace(normalize) """ <...> File "sqrt.py", line 9, in normalize return x / sqrt(len(x)) File "pytorch/torch/fx/proxy.py", line 161, in __len__ raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope """
此错误表明内置函数 len
不被支持。我们可以使用 wrap()
API,将这类函数在追踪信息中以直接调用的形式记录下来。
torch.fx.wrap('len') torch.fx.wrap('sqrt') traced = torch.fx.symbolic_trace(normalize) print(traced.code) """ import math def forward(self, x): len_1 = len(x) sqrt_1 = math.sqrt(len_1); len_1 = None truediv = x / sqrt_1; x = sqrt_1 = None return truediv """
使用Tracer
类进行自定义追踪
The Tracer
类是 symbolic_trace
的实现基础。可以通过继承和扩展 Tracer 类来自定义追踪行为,例如:
class MyCustomTracer(torch.fx.Tracer): # Inside here you can override various methods # to customize tracing. See the `Tracer` API # reference pass # Let's use this custom tracer to trace through this module class MyModule(torch.nn.Module): def forward(self, x): return torch.relu(x) + torch.ones(3, 4) mod = MyModule() traced_graph = MyCustomTracer().trace(mod) # trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable traced = torch.fx.GraphModule(mod, traced_graph)
叶模块
叶模块是指在符号跟踪中以调用形式出现而非被追踪的模块。默认情况下,叶模块包括标准的 torch.nn
模块实例。例如:
class MySpecialSubmodule(torch.nn.Module): def forward(self, x): return torch.neg(x) class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 4) self.submod = MySpecialSubmodule() def forward(self, x): return self.submod(self.linear(x)) traced = torch.fx.symbolic_trace(MyModule()) print(traced.code) # `linear` is preserved as a call, yet `submod` is traced though. # This is because the default set of "Leaf Modules" includes all # standard `torch.nn` modules. """ import torch def forward(self, x): linear_1 = self.linear(x); x = None neg_1 = torch.neg(linear_1); linear_1 = None return neg_1 """
可以通过覆盖Tracer.is_leaf_module()
来自定义叶模块的集合。
杂项
-
目前,张量构造函数(例如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
和torch.sparse_coo_tensor
)不可追踪。-
确定性的构造函数(如
zeros
和ones
)可以使用,它们生成的值将作为常量嵌入到跟踪中。只有当这些构造函数的参数引用动态输入大小时,才会出现问题。在这种情况下,ones_like
或zeros_like
可能是更好的替代方案。 -
非确定性构造函数(
rand
,randn
)会在跟踪中嵌入一个单一的随机值,这可能不是预期的行为。一种解决方法是将torch.randn
包裹在一个torch.fx.wrap
函数中,并调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
-
这个问题可能会在未来的一个版本中得到修复。
-
-
类型注解
-
Python 3 风格的类型注解(例如
func(x: torch.Tensor, y: int) -> torch.Tensor
)是受支持的,并且会被符号跟踪保留。 -
当前不支持 Python 2 风格的注释类型标注,例如:
# type: (torch.Tensor, int) -> torch.Tensor
-
当前不支持在函数内部的局部名称上使用注解。
-
-
关于
training
标志和子模块的相关注意事项-
当使用像
torch.nn.functional.dropout
这样的功能时,训练参数通常会以self.training
的形式传递。在FX追踪过程中,这个参数很可能被固定为一个常量值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
-
然而,当使用标准的
nn.Dropout()
子模块时,训练标志会被封装。由于保留了nn.Module
对象模型,因此可以修改这个标志。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
-
由于这种差异,考虑将与
training
标志动态交互的模块标记为叶模块。
API参考
- torch.fx.symbolic_trace(root, concrete_args=None)[源代码]
-
符号追踪API
给定一个
nn.Module
或函数实例root
,此函数将返回一个通过在跟踪root
时记录的操作构建的GraphModule
。concrete_args
允许你部分地专门化你的函数,以移除控制流或数据结构。例如:
def f(a, b): if b == True: return a else: return a*2
由于存在控制流,FX 通常无法进行追踪。但是,我们可以使用 concrete_args 来专门针对 b 的值进行追踪。
f = fx.symbolic_trace(f, concrete_args={'b': False}) assert f(3, False) == 6
请注意,尽管你可以传入不同的b值,但这些值将会被忽略。
我们还可以使用concrete_args来移除函数中的数据结构处理。这会利用pytrees展开输入参数。为了避免过度专门化,对于不需要专门化的值应传递fx.PH。例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) assert f({'a': 1, 'b': 2, 'c': 4}) == 7
- 参数
-
-
root (Union[torch.nn.Module, Callable]) – 需要被追踪并转换为图表示的模块或函数。
-
concrete_args (Optional[Dict[str, any]]) – 需要部分特化的输入参数
-
- 返回值
-
由root记录的操作生成的模块。
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- torch.fx.wrap(fn_or_name)[源代码]
-
此函数可以在模块级别的作用域中调用,用于将 fn_or_name 注册为“叶子函数”。一个“叶子函数”在 FX 跟踪中会保留为 CallFunction 节点,而不会被进一步追踪。
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap('my_custom_function') def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数还可以等效地用作装饰器:
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
一个被包装的函数可以视为一个“叶函数”,类似于“叶模块”的概念。也就是说,在FX跟踪中,这些函数会被保留在调用层面,而不进行内部追踪。
- 参数
-
fn_or_name (Union[str, Callable]) – 在调用时插入图中的函数或全局函数的名称
注意
此 API 的向后兼容性得到了保证。
- 类torch.fx.GraphModule(*args, **kwargs)[源代码]
-
GraphModule 是由 fx.Graph 生成的一个 nn.Module。它拥有一个
graph
属性,以及从该graph
生成的code
和forward
属性。警告
当
graph
被重新赋值时,code
和forward
将自动再生。然而,如果你在不重新赋值graph
属性的情况下编辑了graph
的内容,你需要调用recompile()
来更新生成的代码。注意
此 API 的向后兼容性得到了保证。
- __init__(root, graph, class_name='GraphModule')[源代码]
-
创建一个GraphModule。
- 参数
-
-
root (Union[torch.nn.Module, Dict[str, Any]) –
root
可以是一个 nn.Module 实例,也可以是将字符串映射到任意属性类型的字典。如果root
是一个 Module,则图中的节点的target
字段中引用的基于 Module 的对象(通过全限定名)将会从root
的 Module 层次结构复制到 GraphModule 的模块层次结构。如果root
是一个字典,则节点中的target
中的全限定名将直接在字典中查找,该对象将会被复制到 GraphModule 模块层次结构中的适当位置。 -
graph (Graph) –
graph
包含此 GraphModule 用于代码生成的节点 -
class_name (str) –
name
表示此 GraphModule 的调试名称。如果未设置,所有错误消息将默认显示为来自GraphModule
。将其设置为root
的原始名称或在转换上下文中具有意义的名称可能会有所帮助。
-
注意
此 API 的向后兼容性得到了保证。
- add_submodule(target, m)[源代码]
-
将给定的子模块添加到
self
。如果它们是
target
的子路径且尚不存在,这会安装空的模块。- 参数
- 返回值
-
- 子模块是否可以被插入。
-
为了使该方法返回 True,
target
标注的链中每个对象必须要么尚未存在,要么引用一个nn.Module
(而不是参数或其它属性)。
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- 属性 代码: str
-
返回由该
GraphModule
底层的Graph
生成的 Python 代码。
- delete_all_unused_submodules()[源代码]
-
从
self
中删除所有未使用的子模块。如果满足以下任一条件,则认为一个模块被“使用”了:1. 它有被使用的子模块 2. 其 forward 方法通过一个
call_module
节点直接调用 3. 它有一个非模块属性,该属性从get_attr
节点中被使用此方法可以被调用以清理一个
nn.Module
,而无需手动为每个未使用的子模块调用delete_submodule
。注意
此 API 的向后兼容性得到了保证。
- delete_submodule(target)[源代码]
-
从
self
中删除指定的子模块。如果
target
不是有效的目标,模块将不会被删除。- 参数
-
target (str) – 新子模块的完整qualified字符串名称(参见
nn.Module.get_submodule
中的示例,了解如何指定完整qualified字符串。) - 返回值
-
- 目标字符串是否引用了某项内容
-
这个返回值表示我们要删除的是否是一个有效的子模块引用。如果返回值是
False
,则说明target
不是一个有效的子模块引用。
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- property graph: Graph
-
返回此
GraphModule
的底层Graph
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False)[源代码]
-
返回当前 GraphModule 及其子模块生成的 Python 代码
警告
此 API 是实验性的,且不支持向下兼容。
- recompile()[源代码]
-
根据其
graph
属性重新编译此 GraphModule。在编辑了包含的graph
之后,应调用此方法,否则此GraphModule
生成的代码将变得过时。注意
此 API 的向后兼容性得到了保证。
- 返回类型
-
Python代码
- to_folder(folder, module_name='FxModule')[源代码]
-
-
将模块输出到指定的
folder
文件夹中,并使用module_name
进行命名,以便于后续操作。 -
使用
from <folder> import <module_name>
进行导入参数:
folder (Union[str, os.PathLike]): 写入代码的目标文件夹
-
module_name (str): 指定的顶级名称,用于标识
Module
-
写代码
-
module_name (str): 指定的顶级名称,用于标识
警告
此 API 是实验性的,且不支持向下兼容。
-
将模块输出到指定的
- 类torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]
-
Graph
是 FX 中间表示中主要的数据结构。它由一系列的Node
组成,每个节点代表调用点(或其他语法构造)。这些Node
一起构成一个有效的 Python 函数。例如,以下示例代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) m = MyModule() gm = torch.fx.symbolic_trace(m)
将生成如下图形:
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1
有关
Graph
中操作的语义,请参见Node
。注意
此 API 的向后兼容性得到了保证。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[源代码]
-
创建一个空的图。
注意
此 API 的向后兼容性得到了保证。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[源代码]
-
在
Graph
中插入一个call_function
Node
。这个call_function
节点表示对由the_function
指定的Python可调用对象的调用。- 参数
-
-
the_function (Callable[..., Any]) – 需要调用的函数。它可以是任何 PyTorch 操作符、Python 函数或
builtins
和operator
命名空间中的成员。 -
args (Optional[Tuple[Argument, ...]]) – 被调用函数所需的位置参数。
-
kwargs (Optional[Dict[str, Argument]]) – 传递给调用函数的关键词参数
-
type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。
-
- 返回值
-
新创建并插入的
call_function
节点。 - 返回类型
注意
此方法的插入点和类型表达式的规则与
Graph.create_node()
方法的规则相同。注意
此 API 的向后兼容性得到了保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[源代码]
-
在
Graph
中插入一个call_method
Node
。这个call_method
节点表示对args
中的第0个元素调用指定的方法。- 参数
-
-
method_name (str) – 指定要应用于 self 参数的方法名。例如,如果 args[0] 是一个表示
Tensor
的Node
,则要对该Tensor
调用relu()
方法,应将relu
传递给method_name
。 -
args (Optional[Tuple[Argument, ...]]) – 被调用方法所需的位置参数。请注意,这 应该 包括一个
self
参数。 -
kwargs (Optional[Dict[str, Argument]]) – 传递给被调用方法的关键词参数
-
type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。
-
- 返回值
-
新创建并插入的
call_method
节点。 - 返回类型
注意
此方法的插入点和类型表达式的规则与
Graph.create_node()
方法的规则相同。注意
此 API 的向后兼容性得到了保证。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[源代码]
-
在
Graph
中插入一个call_module
Node
。这个call_module
节点表示对Module
层次结构中的某个Module
的forward()
函数的调用。- 参数
-
-
module_name (str) – 要调用的
Module
的全限定名称。例如,如果被跟踪的Module
包含一个名为foo
的子模块,并且该子模块包含一个名为bar
的子模块,则应将全限定名称foo.bar
作为参数传递给module_name
,以调用该模块。 -
args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的位置参数。请注意,不应包含
self
参数。 -
kwargs (Optional[Dict[str, Argument]]) – 传递给被调用方法的关键词参数
-
type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。
-
- 返回值
-
新创建并插入的
call_module
节点。 - 返回类型
注意
此方法的插入点和类型表达式的规则与
Graph.create_node()
方法的规则相同。注意
此 API 的向后兼容性得到了保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[源代码]
-
创建一个
Node
并将其添加到当前的Graph
插入点。请注意,可以通过Graph.inserting_before()
和Graph.inserting_after()
来设置当前的插入点。- 参数
-
-
op (str) – 该节点的指令码。可以是 ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’ 或 ‘output’ 中的一个。这些指令码的具体含义在
Graph
的文档字符串中进行了描述。 -
args (Optional[Tuple[Argument, ...]]) – 此节点的参数以元组形式表示。
-
kwargs (Optional[Dict[str, Argument]]) – 此节点的 kwargs 参数
-
name (可选[str]) – 可选的字符串名称,用于标识
Node
。这将影响在生成的 Python 代码中该值的变量名。 -
type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。
-
- 返回值
-
新创建并插入的节点。
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- eliminate_dead_code(is_impure_node=None)[源代码]
-
根据每个节点的用户数量和是否具有任何副作用,删除图中的所有死代码。在调用之前,必须先对图进行拓扑排序。
- 参数
-
-
is_impure_node (Optional[Callable[ [Node] , bool]] ) – 返回一个函数
-
None ((节点是否为不纯。如果是)) –
-
到 (即默认行为)–
-
Node.is_impure (使用)–
-
- 返回值
-
该pass是否改变了图。
- 返回类型
示例:
在删除死代码之前,a = x + 1 中的 a 没有被使用,因此可以从图中移除而不产生任何影响。
def forward(self, x): a = x + 1 return x + self.attr_1
消除死代码后,a = x + 1 被移除,而forward的其他部分保持不变。
def forward(self, x): return x + self.attr_1
警告
死代码消除具有一些启发式方法以避免移除有副作用的节点(参见 Node.is_impure),但总体覆盖率很低。因此,除非你知道你的FX图完全由函数操作组成,或者你提供了自己的自定义函数来检测有副作用的节点,否则你应该假设此方法不可靠。
注意
此 API 的向后兼容性得到了保证。
- erase_node(to_erase)[源代码]
-
从
Graph
中移除一个Node
。如果Graph
中仍然存在该节点的使用者,则会抛出异常。- 参数
-
to_erase (Node) - 需要从
Graph
中删除的节点。
注意
此 API 的向后兼容性得到了保证。
- find_nodes(*, op, target=None, sort=True)[源代码]
-
支持快速查询节点
- 参数
- 返回值
-
具有指定操作和目标的节点集合的可迭代对象。
警告
此 API 是实验性的,且不支持向下兼容。
- get_attr(qualified_name, type_expr=None)[源代码]
-
在图中插入一个
get_attr
节点。这个get_attr
节点表示从Module
层次结构中获取属性。- 参数
-
-
qualified_name (str) – 要检索的属性的完全限定名称。例如,如果被跟踪模块包含一个名为
foo
的子模块,该子模块又包含一个名为bar
的子模块,并且这个子模块有一个名为baz
的属性,则应将完全限定名称foo.bar.baz
作为参数传递。 -
type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。
-
- 返回值
-
新创建并插入的
get_attr
节点。 - 返回类型
注意
此方法的插入点和类型表达式规则与
Graph.create_node
方法相同。注意
此 API 的向后兼容性得到了保证。
- graph_copy(g, val_map, return_output_node=False)[源代码]
-
将给定图形中的所有节点复制到
self
。- 参数
- 返回值
-
如果
g
有一个output
节点,那么现在与g
的输出值等价的self
中的值。否则为None
。 - 返回类型
-
Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 的向后兼容性得到了保证。
- inserting_after(n=None)[源代码]
-
- 设置创建节点及其相关方法在图中的插入位置。
-
当在“with”语句中使用时,这会暂时设置插入点,并在“with”语句结束时恢复:
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
参数:
- n (Optional[Node]): 要在其之前插入的节点。如果为 None,则在末尾插入。
-
图形的开始部分。
- 返回:
-
一个资源管理器,将在
__exit__
时恢复插入位置。
注意
此 API 的向后兼容性得到了保证。
- inserting_before(n=None)[源代码]
-
- 设置创建节点及其相关方法在图中的插入位置。
-
当在“with”语句中使用时,这会暂时设置插入点,并在“with”语句结束时恢复:
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
参数:
- n (Optional[Node]): 要在其前插入的节点。若为 None,则在开头插入。
-
图形的开始部分。
- 返回:
-
一个资源管理器,将在
__exit__
时恢复插入位置。
注意
此 API 的向后兼容性得到了保证。
- lint()[源代码]
-
对此图进行各种检查,确保其结构正确。具体包括:
- 确认节点的拥有权是否正确(由该图所有)
- 验证节点是否按拓扑顺序排列
- 如果此图属于一个GraphModule,则检查目标是否存在于此GraphModule中注意
此 API 的向后兼容性得到了保证。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[源代码]
-
将一个节点从一个图复制到另一个图中。
arg_transform
需要将源节点图中的参数转换为目标图中的相应参数。示例:# Copying all the nodes in `g` into `new_graph` g : torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
- 参数
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- 属性节点: _node_list
-
获取组成此图的节点列表。
请注意,这个
Node
列表表示的是一个双向链表。在迭代过程中进行的修改操作(如删除或添加一个节点)是安全的。- 返回值
-
一个由节点组成的双向链表。注意,可以通过调用
reversed
来切换迭代顺序。
- on_generate_code(make_transformer)[源代码]
-
在生成 Python 代码时注册一个变换函数
- 参数:
-
- make_transformer ( Callable[[ Optional[TransformCodeFunc] ], TransformCodeFunc] ):
-
一个用于返回并注册代码转换器的函数。该函数被on_generate_code调用以获取代码转换器。
此函数还将当前注册的代码转换器(如果没有注册则为 None)作为输入参数传递,以防止意外覆盖。这样可以方便地将多个代码转换器串联起来使用。
- 返回:
-
一个上下文管理器,在使用 with 语句时,会自动恢复之前注册的代码转换器。
示例:
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code( lambda _: insert_pdb ) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb( current_trans(body) if current_trans else body ) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函数还可以用作上下文管理器,能够自动恢复之前注册的代码转换器。
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 是实验性的,且不支持向下兼容。
- output(result, type_expr=None)[源代码]
-
在
Graph
中插入一个output
Node
。一个output
节点表示Python代码中的return
语句。result
是要返回的结果。- 参数
-
-
result (参数) - 需要返回的值。
-
type_expr (Optional[Any]) – 表示该节点输出的 Python 类型的可选类型注解。
-
注意
此方法的插入点和类型表达式规则与
Graph.create_node
方法相同。注意
此 API 的向后兼容性得到了保证。
- placeholder(name, type_expr=None, default_value)[源代码]
-
在图中插入一个
占位符
节点,它表示函数的输入。- 参数
-
-
name (str) – 输入值的名称。这对应于该
Graph
表示的函数的位置参数。 -
type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出的 Python 类型。在某些情况下(例如,在 TorchScript 编译过程中使用该函数时),这对于正确的代码生成是必需的。
-
default_value (Any) – 该函数参数的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 传递给此参数以指定参数没有默认值。
-
- 返回类型
注意
此方法的插入点和类型表达式规则与
Graph.create_node
方法相同。注意
此 API 的向后兼容性得到了保证。
- print_tabular()[源代码]
-
以表格形式打印图的中间表示。请注意,使用此功能需要先安装
tabulate
模块。注意
此 API 的向后兼容性得到了保证。
- process_inputs(*args)[源代码]
-
处理参数,以便可以将其传递给FX图。
警告
此 API 是实验性的,且不支持向下兼容。
- process_outputs(out)[源代码]
-
警告
此 API 是实验性的,且不支持向下兼容。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[源代码]
-
将这个
Graph
转换为有效的 Python 代码。- 参数
-
root_module (str) – 根模块的名称,用于查找qualified name目标。这通常是‘self’。
- 返回值
-
src: 表示对象的 Python 源代码;globals: 包含 src 中全局名称及其引用对象的字典。
- 返回类型
-
一个PythonCode对象,包含两个字段。
注意
此 API 的向后兼容性得到了保证。
- set_codegen(codegen)[源代码]
-
警告
此 API 是实验性的,且不支持向下兼容。
- classtorch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[源代码]
-
Node
是表示Graph
中各个操作的数据结构。通常情况下,Node
表示对各种实体(如运算符、方法和模块)的调用位置。一些例外情况包括指定函数输入和输出的节点。每个Node
都有一个由其op
属性定义的函数。对于不同的op
值,Node
的语义如下:-
placeholder
用于表示函数输入,name
属性指定该值的名称。target
同样是参数的名称。而args
可以包含:1)无内容,或 2)单个默认参数。kwargs
不重要。占位符对应于图打印中的函数参数(例如x
)。 -
get_attr
从模块层次结构中获取一个参数。其中,name
是获取结果的变量名,target
是该参数在模块层次结构中的完整限定名称。args
和kwargs
则无关紧要。 -
call_function
将一个自由函数应用于某些值。其中name
是要分配给该值的名称,target
是要应用的函数。args
和kwargs
表示传递给函数的参数,遵循 Python 的调用约定。 -
call_module
使用给定的参数调用模块层次结构中的forward()
方法。name
与之前相同。target
是要调用的模块在模块层次结构中的全限定名称,而args
和kwargs
表示用于调用该模块的参数(不包括 self 参数)。 -
call_method
在一个值上调用方法。同样地,name
也类似。target
是要应用于self
参数的方法的字符串名称。args
和kwargs
表示调用该方法时的所有参数,包括 self 参数 -
output
包含了被追踪函数的返回值,存储在其args[0]
属性中。这与 Graph 输出中的“return”语句相对应。
注意
此 API 的向后兼容性得到了保证。
- 属性all_input_nodes:List[Node]
-
返回所有输入到该节点的节点。这相当于遍历
args
和kwargs
,并只收集其中为节点的值。- 返回值
-
按顺序出现在此
Node
的args
和kwargs
中的Nodes
列表。
- append(x)[源代码]
-
在此节点的节点列表中插入
x
。相当于执行self.next.prepend(x)
。- 参数
-
x (Node) – 需要放置在此节点后的节点。此节点和目标节点必须属于同一张图。
注意
此 API 的向后兼容性得到了保证。
- propertyargs:Tuple[Optional[Union[Tuple[Any,...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]],...]
-
这是传递给
Node
的参数元组。参数的具体含义由节点的操作码决定。更多详细信息,请参阅Node
的文档字符串。允许对此属性进行赋值。在赋值时,所有使用情况和用户的相关记录会自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[源代码]
-
返回
self
的描述性字符串表示。此方法可以在不带任何参数的情况下用作调试工具。
此函数还用于
Graph
类的内部__str__
方法中。在该图周围 GraphModule 中自动生成的forward
函数签名由placeholder_names
和maybe_return_typename
中的字符串组成。placeholder_names
和maybe_return_typename
不应以其他方式使用。- 参数
- 返回值
-
-
如果 1) 我们使用
format_node
作为内部辅助函数 -
在
Graph
的__str__
方法中,如果self
是一个占位符节点,则返回None
。否则,返回当前节点的描述性字符串表示。
-
如果 1) 我们使用
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- insert_arg(idx, arg)[源代码]
-
在指定索引的位置向参数列表中插入一个参数。
- 参数
-
-
idx (int) – 在
self.args
中要插入的元素之前的索引位置。 -
arg (参数) – 需要插入到
args
中的新参数值
-
注意
此 API 的向后兼容性得到了保证。
- is_impure()[源代码]
-
返回该操作是否为不纯操作,即该操作是一个占位符或输出,或者是不纯的 call_function 或 call_module。
- 返回值
-
如果是不纯的操作或不是纯函数。
- 返回类型
警告
此 API 是实验性的,且不支持向下兼容。
- propertykwargs:Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload,SymInt, SymBool, SymFloat]]
-
这是传递给
Node
的关键词参数字典。参数的具体含义取决于节点的操作码。更多详细信息,请参阅Node
的文档字符串。允许对此属性进行赋值。在赋值时,所有使用情况和用户的相关记录会自动更新。
- 属性next:Node
-
返回链表中下一个
Node
。- 返回值
-
在节点链表中的下一个
Node
。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[源代码]
-
返回标准化后的参数给 Python 目标。这意味着 args/kwargs 将与模块或函数的签名匹配,并在 normalize_to_only_use_kwargs 为 true 时按位置顺序返回仅 kwargs 参数。还会填充默认值。不支持仅限位置的参数和可变参数。
支持模块调用。
可能需要arg_types和kwarg_types来区分函数的重载。
- 参数
-
-
root (torch.nn.Module) – 用于解析模块目标的基础模块。
-
arg_types (Optional[Tuple[Any]]) – 参数类型的元组
-
kwarg_types (Optional[Dict[str, Any]]) – kwargs 的类型参数字典
-
normalize_to_only_use_kwargs (bool) – 是否将规范改为只使用关键字参数。
-
- 返回值
-
返回 NamedTuple ArgsKwargsPair,如果失败则返回 None。
- 返回类型
警告
此 API 是实验性的,且不支持向下兼容。
- prepend(x)[源代码]
-
在此节点前面的图的节点列表中插入 x。 示例:
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 参数
-
x (Node) – 需要放置在当前节点之前的一个节点,且必须属于同一张图。
注意
此 API 的向后兼容性得到了保证。
- 属性prev: Node
-
返回链表中之前的
Node
节点。- 返回值
-
前一个
Node
,在节点的链表中。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[源代码]
-
将图中所有的
self
替换为节点replace_with
。- 参数
- 返回值
-
在此更改中涉及的节点列表。
- 返回类型
-
注意
此 API 的向后兼容性得到了保证。
- replace_input_with(old_input, new_input)[源代码]
-
遍历
self
的输入节点,将其中所有的old_input
替换为new_input
。注意
此 API 的向后兼容性得到了保证。
- 属性stack_trace:Optional[str]
-
返回在跟踪过程中记录的 Python 堆栈跟踪(如果有)。通常,当使用 fx.Tracer 进行跟踪时,此属性会由 Tracer.create_proxy 方法填充。为了在调试目的下记录堆栈跟踪,请在 Tracer 实例上设置 record_stack_traces = True。默认情况下,当使用 dynamo 进行跟踪时,此属性将由 OutputGraph.create_proxy 填充。
stack_trace会在字符串的末尾包含最内层的帧信息。
-
- 类torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[源代码]
-
Tracer
是实现torch.fx.symbolic_trace
符号跟踪功能的类。调用symbolic_trace(m)
等同于调用Tracer().trace(m)
。可以通过子类化 Tracer 来覆盖跟踪过程中的各种行为。可以覆盖的行为在该类方法的文档字符串中有详细说明。
注意
此 API 的向后兼容性得到了保证。
- call_module(m, forward, args, kwargs)[源代码]
-
定义此
Tracer
在遇到nn.Module
实例的调用时应如何处理的方法。默认情况下,会通过
is_leaf_module
检查被调用的模块是否为叶模块。如果是,则在Graph
中生成一个引用m
的call_module
节点。否则,正常调用该Module
并跟踪其forward
函数中的操作。你可以重写此方法来创建嵌套的追踪 GraphModules,或在跨越
Module
边界时实现其他所需的行为。- 参数
-
-
m (Module) – 要进行调用的模块
-
forward (Callable) –
Module
中要调用的 forward() 方法 -
args (元组) – 模块调用位置的参数
-
kwargs (Dict) – 模块调用位置的 kwargs 参数
-
- 返回值
-
模块调用的返回值。如果是
call_module
节点发出的情况,那么返回的是一个Proxy
值。否则,返回的就是从Module
调用中得到的实际值。 - 返回类型
注意
此 API 的向后兼容性得到了保证。
- create_arg(a)[源代码]
-
一种指定在为
图
中节点参数准备值时进行追踪行为的方法。默认的行为包括:
-
遍历各种集合类型(如元组、列表、字典),并对其内的元素递归调用
create_args
。 -
给定一个Proxy对象,返回底层IR
Node
的引用。 -
对于给定的非代理张量对象,为不同情形生成IR代码:
-
对于一个参数,生成一个指向该参数的
get_attr
节点 -
对于非参数张量,将其存储在一个特殊的属性中,并引用该属性。
-
可以重写此方法以支持更多类型的对象。
- 参数
-
a (任意) – 在
Graph
中作为Argument
发射的值。 - 返回值
-
将值
a
转换为适当的Argument
- 返回类型
-
参数
注意
此 API 的向后兼容性得到了保证。
-
- create_args_for_root(root_fn, is_module, concrete_args=None)[源代码]
-
创建与
root
模块签名对应的placeholder
节点。此方法会分析根模块的签名并相应地生成这些节点,同时也支持*args
和**kwargs
。警告
此 API 是实验性的,且不支持向下兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)
-
根据目标、参数 args、关键字参数 kwargs 和名称来插入一个图节点。
可以重写此方法来执行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录就地操作。
注意
此 API 的向后兼容性得到了保证。
- 返回类型
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)
-
根据给定的参数创建一个节点,然后将其封装在代理对象中返回。
如果 kind = ‘placeholder’,那么我们创建一个表示函数参数的 Node。如果我们需要为默认参数编码,则使用
args
元组。args
对于placeholder
Nodes 而言则为空。注意
此 API 的向后兼容性得到了保证。
- getattr(attr, attr_val, parameter_proxy_cache)[源代码]
-
定义此
Tracer
在调用nn.Module
实例时处理getattr
方法的行为。默认情况下,会为属性返回一个代理值。此外,它还会将该代理值存储在
parameter_proxy_cache
中,这样未来的调用可以重复使用这个代理,而无需重新创建一个新的。此方法可以被重写,以便在查询参数时不要返回代理。
- 参数
- 返回值
-
getattr 调用返回的值。
警告
此 API 是实验性的,且不支持向下兼容。
- is_leaf_module(m, module_qualified_name)[源代码]
-
一种方法,用于指定给定的
nn.Module
是否为“叶子”模块。叶模块是出现在IR中的基本单元,通过
call_module
调用进行引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块都是叶模块。除非另有指定,否则所有其他模块都将被追踪,并记录其内部的操作。- 参数
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- iter(obj)
-
- 在迭代代理对象时被调用,例如:
-
在控制流中使用时,通常我们不知道该如何操作,因为不知道代理对象的值。但是,一个自定义追踪器可以使用 create_node 向图节点添加更多详情,并且可以选择返回一个迭代器。
注意
此 API 的向后兼容性得到了保证。
- 返回类型
- keys(obj)
-
- 当代理对象调用 keys() 方法时触发。
-
这是在代理上调用 ** 时发生的情况。它应该返回一个迭代器,并且 ** 应该在你的自定义追踪器中正常工作。
注意
此 API 的向后兼容性得到了保证。
- 返回类型
- path_of_module(mod)[源代码]
-
这是一个辅助方法,用于在
root
的模块层次结构中查找mod
的全名。例如,如果root
有一个名为foo
的子模块,并且该foo
子模块还有一个名为bar
的子模块,则将bar
传递给此函数将会返回字符串“foo.bar”。注意
此 API 的向后兼容性得到了保证。
- 代理(节点)
-
注意
此 API 的向后兼容性得到了保证。
- 返回类型
- to_bool(obj)
-
- 当将代理对象转换为布尔值时调用,例如
-
在控制流中使用时,通常我们不知道该如何操作,因为不知道代理对象的值。但是,一个自定义追踪器可以使用 create_node 向图节点添加更多相关信息,并且可以选择返回一个具体的值。
注意
此 API 的向后兼容性得到了保证。
- 返回类型
- trace(root, concrete_args=None)[源代码]
-
追踪
root
并返回对应的 FXGraph
表示。这里的root
可以是一个nn.Module
实例或一个 Python 函数。注意,在此调用之后,
self.root
可能与传入的root
不同。例如,当一个自由函数被传递给trace()
时,我们会创建一个nn.Module
实例作为根节点,并添加嵌入常量。- 参数
- 返回值
-
表示传入的
root
语义的Graph
。 - 返回类型
注意
此 API 的向后兼容性得到了保证。
- 类torch.fx.Proxy(node, tracer=None)[源代码]
-
Proxy
对象是Node
的包装器,在符号跟踪过程中流经程序,并将它们接触的所有操作(如torch
函数调用、方法调用和运算符)记录到不断增长的 FX 图中。如果你在进行图变换,可以为原始的
Node
包装一个自定义的Proxy
方法,以便使用重载的操作符向Graph
添加额外的内容。Proxy
对象不能被迭代。换句话说,如果在循环中或作为*args
/**kwargs
函数参数使用Proxy
,符号追踪器会抛出错误。有两种主要的方法来解决这个问题:1. 将不可追踪的逻辑提取到一个顶级函数中,并使用
fx.wrap
包围它。2. 如果控制流是静态的(即循环迭代次数基于某些超参数),代码可以保持在原来的位置并进行重构,使其类似于:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
了解更多关于代理内部的详细信息,请参阅torch/fx/README.md中的“Proxy”部分。
注意
此 API 的向后兼容性得到了保证。
- 类torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[源代码]
-
解释器按节点逐一执行FX图。这种模式很有用,可以用于编写代码转换和进行分析等任务。
可以重写 Interpreter 类中的方法来自定义执行行为。根据调用层次结构,可覆盖的方法如下:
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想将所有
torch.neg
实例替换为torch.sigmoid
,反之亦然(包括它们的Tensor
方法等价物)。我们可以像这样子类化 Interpreter:class NegSigmSwapInterpreter(Interpreter): def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 参数
-
-
module (torch.nn.Module) – 需要执行的模块
-
garbage_collect_values (bool) – 是否在模块执行过程中删除值的最后一次使用后进行清理。这确保了执行过程中的最优内存使用。可以通过禁用此功能来检查所有中间值,例如查看
Interpreter.env
属性。 -
graph (可选[Graph]) – 如果提供了此参数,解释器将使用该图进行执行,而不是module.graph,并利用提供的module来满足状态请求。
-
注意
此 API 的向后兼容性得到了保证。
- boxed_run(args_list)[源代码]
-
通过解释执行模块并返回结果。这采用“带框”调用约定:传递一个参数列表,该列表会被解释器清空,从而确保输入张量能被及时释放。
注意
此 API 的向后兼容性得到了保证。
- call_function(target, args, kwargs)[源代码]
-
执行
call_function
节点并返回结果。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回类型
- 返回值
-
Any: 函数调用返回的值
注意
此 API 的向后兼容性得到了保证。
- call_method(target, args, kwargs)[源代码]
-
执行
call_method
节点并返回结果。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回类型
- 返回值
-
Any:方法调用返回的值
注意
此 API 的向后兼容性得到了保证。
- call_module(target, args, kwargs)[源代码]
-
执行一个
call_module
节点并返回结果。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回类型
- 返回值
-
Any: 模块调用返回的值
注意
此 API 的向后兼容性得到了保证。
- fetch_args_kwargs_from_env(n)[源代码]
-
从当前执行环境获取节点
n
的args
和kwargs
的具体值。- 参数
-
n (Node) – 获取
args
和kwargs
的节点。 - 返回值
-
带有具体值的
args
和kwargs
,用于n
。 - 返回类型
-
元组[元组, 字典]
注意
此 API 的向后兼容性得到了保证。
- fetch_attr(target)[源代码]
-
从
self.module
的Module
层次结构中获取一个属性。- 参数
-
target (str) – 需要获取的目标属性的完整名称
- 返回值
-
该属性的值。
- 返回类型
-
任何
注意
此 API 的向后兼容性得到了保证。
- get_attr(target, args, kwargs)[源代码]
-
执行一个
get_attr
节点,从中检索self.module
的Module
层次结构中的属性值。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回值
-
检索到的属性的值
- 返回类型
-
任何
注意
此 API 的向后兼容性得到了保证。
- map_nodes_to_values(args, n)[源代码]
-
递归遍历
args
,并在当前执行环境中为每个Node
查找具体的值。- 参数
-
-
args (参数) – 用于查找具体值的数据结构
-
n (Node) —
args
所属的节点。这仅用于错误报告。
-
- 返回类型
-
Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 的向后兼容性得到了保证。
- output(target, args, kwargs)[源代码]
-
执行一个
output
节点。实际上是获取该节点引用的值并返回它。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回值
-
输出节点所引用的返回值
- 返回类型
-
任何
注意
此 API 的向后兼容性得到了保证。
- placeholder(target, args, kwargs)[源代码]
-
执行一个
placeholder
节点。需要注意的是,这是有状态的操作:Interpreter
维护了一个内部迭代器,用于遍历传递给run
方法的参数,并且此方法返回该迭代器的下一个值。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回值
-
获取的参数值。
- 返回类型
-
任何
注意
此 API 的向后兼容性得到了保证。
- run(*args, initial_env=None, enable_io_processing=True)[源代码]
-
通过解释运行模块并返回结果。
- 参数
- 返回值
-
执行模块返回的值
- 返回类型
-
任何
注意
此 API 的向后兼容性得到了保证。
- 类torch.fx.Transformer(module)[源代码]
-
Transformer
是一种特殊的解释器,它生成一个新的Module
。它提供了一个transform()
方法来返回转换后的Module
。与需要参数的Interpreter
不同,Transformer
运行时不需要任何参数,并且完全以符号方式进行工作。示例
假设我们想将所有
torch.neg
实例替换为torch.sigmoid
,反之亦然(包括它们的Tensor
方法等价物)。我们可以像这样子类化Transformer
:class NegSigmSwapXformer(Transformer): def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 参数
-
module (GraphModule) - 要进行转换的
Module
。
注意
此 API 的向后兼容性得到了保证。
- get_attr(target, args, kwargs)[源代码]
-
执行一个
get_attr
节点。在Transformer
中,会重写此操作以在输出图中插入一个新的get_attr
节点。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- placeholder(target, args, kwargs)[源代码]
-
执行一个
placeholder
节点。在Transformer
中,会重新定义此操作以在输出图中插入一个新的placeholder
。- 参数
-
-
target (Target) - 此节点的调用目标。关于语义的详细信息,请参阅Node
-
args (元组) – 本次调用的位置参数组成的元组
-
kwargs (Dict) – 此次调用的关键词参数字典
-
- 返回类型
注意
此 API 的向后兼容性得到了保证。
- transform()[源代码]
-
将
self.module
转换,并返回转换后的GraphModule
。注意
此 API 的向后兼容性得到了保证。
- 返回类型
- torch.fx.replace_pattern(gm, pattern, replacement)[源代码]
-
在图模块(
gm
)的图中,找到所有可能的非重叠操作符及其数据依赖关系集合(pattern
),然后将每个匹配到的子图替换为另一个子图(replacement
)。- 参数
-
-
gm (GraphModule) - 用于包装并操作图的 GraphModule
-
pattern (Union[Callable, GraphModule]) – 要在
gm
中匹配并替换的子图 -
replacement (Union[Callable, GraphModule]) – 用于替换
pattern
的子图
-
- 返回值
-
表示在原始图中
pattern
匹配位置的Match
对象列表。如果不存在匹配项,该列表将为空。Match
定义如下:class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回类型
-
匹配列表
示例:
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代码将在
traced_module
的forward
方法中首先匹配pattern
。模式匹配是基于使用定义关系进行的,而不是节点名称。例如,如果你在pattern
中有p = torch.cat([a, b])
,你可以在原始的forward
函数中找到与之对应的代码m = torch.cat([a, b])
,即使变量名不同(p
vsm
)。在
pattern
中,return
语句仅根据其值进行匹配;它可能与更大的图中的return
语句匹配,也可能不匹配。换句话说,模式不必延伸到更大图的末尾。当模式匹配成功时,它将从较大的函数中移除,并被
replacement
替换。如果较大函数中有多个pattern
匹配项,则每个非重叠的匹配项都将被替换。对于重叠的匹配项,将在一组重叠的匹配项中找到的第一个匹配项进行替换。(这里的“第一个”是指根据节点使用定义关系的拓扑排序中的第一个节点。通常情况下,第一个节点是直接出现在self
之后的参数,而最后一个节点则是函数返回的内容。)需要注意的是,
pattern
可调用函数中的参数必须在该可调用函数中使用,而replacement
可调用函数的参数必须与模式匹配。这就是为什么在上述代码块中,forward
函数有参数x, w1, w2
,但pattern
函数只有参数w1, w2
。由于pattern
不使用x
参数,因此不应将其指定为参数。作为第二条规则的一个例子,请考虑替换操作。def pattern(x, y): return torch.neg(x) + torch.relu(y)
使用
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement
需要与pattern
相同数量的参数(即x
和y
),即使y
参数在replacement
中没有被使用。调用
subgraph_rewriter.replace_pattern
之后,生成的 Python 代码如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此 API 的向后兼容性得到了保证。