Torch 导出中间表示规范

导出中间表示(IR)是编译器中的一种形式,与MLIR和TorchScript类似。它专门用于表达PyTorch程序的语义。导出IR主要通过简洁的操作列表来表示计算,并且对动态特性如控制流的支持有限。

要创建一个导出中间表示(IR)图,可以使用前端通过跟踪专业化机制安全地捕获 PyTorch 程序。生成的导出 IR 可以由后端进行优化和执行。这可以通过torch.export.export() 今天实现。

本文将介绍的关键概念包括:

  • ExportedProgram:包含导出IR程序的数据结构

  • 图:包含一个节点列表。

  • 节点:表示存储在此节点上的操作、控制流程和元数据。

  • 值由节点产生和消耗。

  • 类型与值和节点关联。

  • 值的大小和内存布局也已定义。

假设

本文档假设读者熟悉 PyTorch,特别是对 torch.fx 及其相关工具有所了解。因此,将不再重复描述torch.fx文档和论文中已有的内容。

什么是Export IR

导出IR是PyTorch程序的一种基于图的中间表示(IR)。它建立在torch.fx.Graph之上。换句话说,所有的导出IR图都是有效的FX图,并且如果使用标准的FX语义进行解释,则导出IR可以被正确地解释。这意味着,通过标准的FX代码生成,导出的图可以转换为有效的Python程序。

本文档主要侧重于突出Export IR在严格性方面与FX的不同之处,而忽略两者相似的部分。

导出程序

顶级的导出中间表示(IR)构造是一个 torch.export.ExportedProgram 类。它将一个 PyTorch 模型的计算图(通常是 torch.nn.Module)与其使用的参数或权重捆绑在一起。

torch.export.ExportedProgram 类的一些重要属性包括:

  • graph_module (torch.fx.GraphModule):包含PyTorch模型展开计算图的数据结构。可以通过ExportedProgram.graph直接访问该图。

  • graph_signature (torch.export.ExportGraphSignature):图签名,它指定了在图中使用和修改的参数及缓冲区名称。参数和缓冲区不是作为图的属性存储,而是被提升为图的输入。graph_signature 用于跟踪这些参数和缓冲区的相关信息。

  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):一个包含模型参数和缓冲区的数据结构。

  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):对于导出时具有数据依赖行为的程序,每个节点的元数据包含类似 s0i0 的符号形状。此属性将这些符号形状映射到它们的下限和上限范围。

图形

导出的IR图是以有向无环图(DAG)形式表示的一个PyTorch程序。图中的每个节点代表一个特定的计算或操作,而图的边则是通过节点之间的引用来组成的。

我们可以查看具有该模式的图形:

class Graph:
  nodes: List[Node]

实际上,Export IR 的图是通过 torch.fx.Graph Python 类实现的。

一个导出的IR图包含以下节点(各节点的具体内容将在下一节中详细介绍):

  • 0 个或多个类型为 占位符 的节点

  • 0 个或多个类型为 call_function 的节点

  • 正好有一个类型为 output 的节点

推论: 最小的有效图包含一个节点。也就是说,节点集合永远不会为空。

定义:图(Graph)中的占位符节点集合代表图模块(GraphModule)的输入。图(Graph)的输出节点表示图模块(GraphModule)的输出

示例:

from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

mod = torch.export.export(MyModule())
print(mod.graph)

上述内容是以文本形式表示的一个图,每行代表一个节点。

节点

A Node 表示特定的计算或操作,并使用 torch.fx.Node 类在 Python 中进行表示。节点之间的边通过 Node 类中的 args 属性直接引用其他节点来表示。利用相同的 FX 机制,我们可以表示计算图通常需要的操作,例如操作符调用、占位符(即输入)、条件语句和循环。

节点采用以下结构:
class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX文本格式

如上面的示例所示,注意每行都采用了以下格式:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以紧凑的方式包含了 Node 类中的所有内容,除了 meta 属性。

具体而言:

  • <name> 是节点在 node.name 中显示的名称。

  • <op_name>node.op 字段的值,该字段必须是以下之一:<call_function><placeholder><get_attr><output>

  • <target> 表示节点的目标,具体为 node.target 的值。该字段的具体意义由 op_name 决定。

  • args1, … args 4…node.args 元组中的内容。如果元组中的某个值是torch.fx.Node对象,则会在其前面特别加上一个%.

例如,调用加法运算符将表现为:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是两个名为 x 和 y 的节点。值得注意的是,字符串 torch.op.aten.add.Tensor 表示实际存储在目标字段中的可调用对象本身,而不仅仅是它的名称。

此文本格式的最后一行是:

return [add]

这是一个具有op_name = output属性的节点,表示我们将返回这个单一元素。

调用函数

call_function 节点表示对操作符的调用。

定义

  • 功能性:如果一个可调用对象满足以下所有要求,则称其为“功能性”的:

    • 非修改操作:该操作不更改其输入的值(对于张量,这包括元数据和实际数据)。

    • 无副作用:操作符不会改变外部可见的状态,例如不修改模块参数的值。

  • 运算符:是一个具有预定义模式的函数式调用。这类运算符的例子包括功能性的ATen运算符。

外汇表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与标准 FX call_function 的区别

  1. 在FX图中,call_function可以引用任何可调用对象。但在导出IR时,我们只允许它使用ATen操作符、自定义操作符和控制流操作符的特定子集。

  2. 在导出IR时,常量参数将会被嵌入到图形中。

  3. 在FX图中,get_attr节点可以表示读取图模块中存储的任意属性。但在导出IR时,只能读取子模块中的属性,因为所有的参数和缓冲区都会被作为输入传递到图模块中。

元数据

Node.meta 是附加到每个 FX 节点的一个字典。然而,FX 规范并没有规定可以或将会存在什么样的元数据。Export IR 提供了一个更强的约定:所有 call_function 节点都将保证具有且仅具有以下元数据字段。

  • node.meta["stack_trace"] 是一个包含引用原始 Python 源代码的 Python 堆栈跟踪的字符串。下面是一个堆栈跟踪的例子:

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述了操作运行的输出。它可以是类型 <symint><FakeTensor>,也可以是一个包含 List[Union[FakeTensor, SymInt]] 的列表,或者为 None

  • node.meta["nn_module_stack"] 描述了节点来源的 torch.nn.Module 的“调用栈”。如果该节点来自一个 torch.nn.Module 调用,例如,从一个位于 torch.nn.Sequential 模块内的 torch.nn.Linear 模块中调用的包含 addmm 操作的节点,那么 nn_module_stack 会看起来像这样:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含了此节点在分解之前调用的 torch 函数或叶 torch.nn.Module 类。例如,一个包含来自 torch.nn.Linear 模块调用的 addmm 操作的节点,在其 source_fn 中会包含 torch.nn.Linear;而一个包含来自 torch.nn.functional.Linear 模块调用的 addmm 操作的节点,在其 source_fn 中会包含 torch.nn.functional.Linear

占位符

占位符代表图的输入,其语义与FX中的完全一致。占位符节点必须是图的节点列表中的前N个节点,其中N可以为零。

外汇表示

%name = placeholder[target = name](args = ())

目标字段是一个字符串,表示输入的名称。

如果 args 不为空,则其大小应为 1,并表示该输入的默认值。

元数据

占位符节点也有像 call_function 节点一样的 meta[‘val’]。在这种情况下,val 字段表示图期望接收的此输入参数的形状和数据类型。

输出

输出调用代表函数中的返回语句,因此会终止当前的图。图中只有一个输出节点,并且它总是位于图的最后。

外汇表示

output[](args = (%something, …))

这与torch.fx 中的语义完全一致。其中,args 表示需要返回的节点。

元数据

输出节点和call_function节点具有相同的元数据。

获取属性

get_attr 节点表示从封装的 torch.fx.GraphModule 中读取子模块。与来自 torch.fx.symbolic_trace() 的普通 FX 图不同,后者使用 get_attr 节点从顶级 torch.fx.GraphModule 读取属性(如参数和缓冲区),在导出的程序中,这些参数和缓冲区作为图模块的输入传递,并存储在顶级 torch.export.ExportedProgram 中。

外汇表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

图表:

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

该行代码 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 读取子模块 true_graph_0,其中包含 sin 操作符。

参考资料

SymInt

A SymInt 是一个对象,它可以是一个具体的整数值或代表整数的符号(在 Python 中由 sympy.Symbol 类表示)。当 SymInt 作为符号时,它描述了一个类型为整数但在编译阶段图中未知的变量,也就是说,它的具体值仅在运行时才能确定。

FakeTensor

FakeTensor 是一个包含张量元数据的对象,可以视为具有以下属性:

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 字段是一个由整数或 SymInt 组成的列表。如果存在 SymInt,说明张量具有动态形状;若只有整数,则表示张量具有固定的静态形状。TensorMeta 的秩始终是固定的。dtype 字段表示该节点输出的数据类型。Edge IR 中没有隐式的类型提升机制。FakeTensor 不包含 strides。

换句话说:

  • 如果节点的目标操作返回一个张量,那么 node.meta['val'] 就是描述这个张量的一个 FakeTensor。

  • 如果节点中的操作符返回一个包含 n 个张量的元组,那么 node.meta['val'] 就会是一个包含每个张量的 FakeTensor 元组。

  • 如果节点中的操作符返回一个编译时已知的整数、浮点数或标量值,则 node.meta['val'] 为 None。

  • 如果节点中的操作符返回一个在编译时未知的整数、浮点数或标量值,则 node.meta['val'] 的类型为 SymInt。

例如:

  • aten::add 返回一个张量,因此它的规格将是一个假张量,具有此操作返回的张量的数据类型和大小。

  • aten::sym_size 返回一个整数,因此它的值会是一个 SymInt,因为在运行时才能确定其具体值。

  • max_pool2d_with_indexes 返回一个包含两个张量的元组 (Tensor, Tensor);因此,规格将是一个包含两个 FakeTensor 对象的 2 元组,其中第一个 TensorMeta 描述返回值的第一个元素等。

Python代码:

def add_one(x):
  return torch.ops.aten(x, 1)

图表:

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor:

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

可以转换为Pytree的类型

我们将一种类型定义为“可构成Pytree的类型”,如果它要么是叶子类型,要么是包含其他可构成Pytree类型的容器类型。

注意:

pytree 的概念与 JAX 文档中的定义相同:

以下类型被定义为叶类型

类型

定义

张量

torch.Tensor

标量

包括整数类型、浮点类型和零维张量在内的任何 Python 数值类型。

整型

Python 中的 int(在 C++ 中绑定为 int64_t)

浮点数

Python中的float类型(在C++中对应为double类型)

布尔值

Python 布尔类型

str

Python字符串

标量类型

torch.dtype

布局

torch.layout

内存格式

torch.memory_format

设备

torch.device

以下类型被定义为容器类型

类型

定义

元组

Python 元组

列表

Python列表

字典

具有标量键的 Python 字典

命名元组(NamedTuple)

Python 命名元组

数据类

必须通过register_dataclass 进行注册

自定义类

通过_register_pytree_node 定义的任何自定义类

本页目录