torch.export

警告

此功能目前处于积极开发阶段的原型,未来可能会有重大变化。

概述

torch.export.export() 接受任意 Python 可调用对象(如 torch.nn.Module、函数或方法),并以 Ahead-of-Time (AOT) 方式生成一个仅表示张量计算的追踪图。此追踪图可以随后使用不同的输入执行,或者进行序列化。

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
            # code: a = torch.sin(x)
            sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

            # code: b = torch.cos(y)
            cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
            return (add,)

    Graph signature: ExportGraphSignature(
        parameters=[],
        buffers=[],
        user_inputs=['arg0_1', 'arg1_1'],
        user_outputs=['add'],
        inputs_to_parameters={},
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}

torch.export 生成一个干净的中间表示(IR),并保持以下不变性。有关 IR 的更多信息,请参见此处规范

  • 正确性:它确保是对原始程序的准确表示,并保留了原始程序的调用约定。

  • 标准化:图中不包含Python语义,来自原始程序的子模块被内联,形成了一个完全展平的计算图。

  • 图属性:该图是纯函数式的,不包含任何具有副作用的操作(如变异或别名)。它不会修改任何中间值、参数或缓冲区。

  • 元数据: 该图包含了在追踪过程中捕获的元数据,例如用户代码中的堆栈跟踪信息。

在底层,torch.export 使用了以下最新技术:

  • TorchDynamo (torch._dynamo) 是一个内部API,它使用CPython的Frame Evaluation API功能来安全地追踪PyTorch图。这提供了极大的改进的图捕获体验,并且为了完全追踪PyTorch代码所需的重写大大减少。

  • AOT Autograd 提供一个功能化的 PyTorch 图,并确保该图被分解或降级为 ATen 操作集。

  • Torch FX (torch.fx) 是图的底层表示,支持灵活的基于 Python 的变换。

现有框架

torch.compile() 也使用与 torch.export 相同的 PT2 堆栈,但有一些差异:

  • JIT 与 AOT: torch.compile() 是一个 JIT 编译器,不用于在部署之外生成编译后的工件。

  • 部分图捕获与完整图捕获: 当torch.compile()遇到无法追踪的模型部分时,它会“图中断”并回退到即时 Python 运行环境中执行程序。相比之下,torch.export旨在获取 PyTorch 模型的完整图表示,因此当遇到不可追踪的部分时会报错。由于 torch.export 生成的是一个与任何 Python 特性或运行环境完全分离的完整图,这个图可以在不同的环境和语言中保存、加载并执行。

  • 可用性权衡: 由于 torch.compile() 能够在遇到无法追踪的代码时回退到 Python 运行时,因此它更加灵活。而 torch.export 则需要用户提供更多信息或重写代码以使其可被追踪。

torch.fx.symbolic_trace()相比,torch.export 使用在 Python 字节码级别操作的 TorchDynamo 进行跟踪,这使得它可以追踪任意的 Python 构造,而不受 Python 操作符重载支持范围的限制。此外,torch.export 会详细记录张量元数据,因此基于张量形状等条件判断不会导致跟踪失败。总体而言,torch.export 预期可以适用于更多的用户程序,并生成更低级别的图(在 torch.ops.aten 操作符级别)。需要注意的是,用户仍然可以在使用 torch.export 之前采用torch.fx.symbolic_trace()作为预处理步骤。

更自然的版本:

torch.fx.symbolic_trace()相比,torch.export 使用 TorchDynamo 在 Python 字节码级别进行跟踪。这使得它可以追踪任意的 Python 构造,而不受 Python 操作符重载支持范围的限制。此外,torch.export 会详细记录张量元数据,因此基于张量形状等条件判断不会导致跟踪失败。总体而言,torch.export 预期可以适用于更多的用户程序,并生成更低级别的图(在 torch.ops.aten 操作符级别)。需要注意的是,在使用 torch.export 之前,用户仍然可以采用torch.fx.symbolic_trace()作为预处理步骤。

torch.jit.script()相比,torch.export不会捕获Python的控制流或数据结构,但它支持比TorchScript更多的Python语言特性(因为它更容易全面覆盖Python字节码)。生成的图更简单,并且只包含直线控制流(除了显式的控制流操作符)。

torch.jit.trace()相比,torch.export更为可靠:它能够追踪进行整数计算以确定大小的代码,并记录所有必要的条件,确保特定的跟踪对其他输入也有效。

导出 PyTorch 模型

示例

主要的入口是通过 torch.export.export(),它接受一个可调用对象(如 torch.nn.Module、函数或方法)和样本输入,并将计算图捕获到一个 torch.export.ExportedProgram 中。例如:

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

            # code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
            );

            # code: a.add_(constant)
            add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

            # code: return self.maxpool(self.relu(a))
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                relu, [3, 3], [3, 3]
            );
            getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
            return (getitem,)

    Graph signature: ExportGraphSignature(
        parameters=['L__self___conv.weight', 'L__self___conv.bias'],
        buffers=[],
        user_inputs=['arg2_1', 'arg3_1'],
        user_outputs=['getitem'],
        inputs_to_parameters={
            'arg0_1': 'L__self___conv.weight',
            'arg1_1': 'L__self___conv.bias',
        },
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}

查看ExportedProgram,我们可以注意到以下几点:

  • torch.fx.Graph 包含了原始程序的计算图,并记录下原始代码以方便调试。

  • 该图仅包含在此处找到的 torch.ops.aten 操作符和自定义操作符,并且是完全功能性的,不包括任何如 torch.add_ 之类的原地操作符。

  • 参数(包括卷积的权重和偏置)被提升为图的输入,因此图中不再包含之前由torch.fx.symbolic_trace()生成的get_attr节点。

  • The torch.export.ExportGraphSignature 定义了输入和输出的签名,并指定了哪些输入是参数。

  • 图中每个节点生成的张量的形状和数据类型会被记录下来。例如,convolution 节点将产生一个数据类型为 torch.float32 且形状为 (1, 16, 256, 256) 的张量。

非严格模式下的导出

在 PyTorch 2.3 中,我们引入了一种新的追踪模式,称为 非严格模式。该模式仍在完善中,因此如果你遇到任何问题,请通过 Github 提交问题,并加上“oncall: export”标签。

非严格模式下,我们使用Python解释器来跟踪程序的执行。你的代码将完全像在急切模式下一样运行;唯一的区别是所有的Tensor对象都将被替换为ProxyTensors,这些ProxyTensors会记录它们的所有操作到一个图中。

严格模式下(当前默认设置),我们首先使用TorchDynamo(一个字节码分析引擎)对程序进行跟踪。TorchDynamo不会执行你的Python代码,而是对其进行符号分析,并基于结果构建图。这种分析使torch.export能够提供更强的安全性保证,但并非所有的Python代码都受支持。

一个可能需要使用非严格模式的例子是,当你遇到 TorchDynamo 不支持的功能并且难以解决时,而你又知道 Python 代码并非完全用于计算。例如:

import contextlib
import torch

class ContextManager():
    def __init__(self):
        self.count = 0
    def __enter__(self):
        self.count += 1
    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

在这个例子中,第一次使用非严格模式(通过strict=False标志)成功执行了追踪,而第二次使用默认的严格模式则失败了,因为TorchDynamo无法支持上下文管理器。一种解决方法是重写代码(参见torch.export 的限制),但由于上下文管理器不影响模型中的张量计算,我们可以选择接受非严格模式的结果。

表达动态性

默认情况下,torch.export 假设所有输入形状都是静态的并为此导出程序进行专门化。然而,某些维度(如批量维度)可以是动态变化的,并且每次运行时都可能不同。必须使用 torch.export.Dim() API 创建这些动态维度并通过torch.export.export()dynamic_shapes 参数传递它们。例如:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

            # code: out1 = self.branch1(x1)
            permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
            addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
            relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

            # code: out2 = self.branch2(x2)
            permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
            addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
            relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None

            # code: return (out1 + self.buffer, out2)
            add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
            return (add, relu_1)

    Graph signature: ExportGraphSignature(
        parameters=[
            'branch1.0.weight',
            'branch1.0.bias',
            'branch2.0.weight',
            'branch2.0.bias',
        ],
        buffers=['L__self___buffer'],
        user_inputs=['arg5_1', 'arg6_1'],
        user_outputs=['add', 'relu_1'],
        inputs_to_parameters={
            'arg0_1': 'branch1.0.weight',
            'arg1_1': 'branch1.0.bias',
            'arg2_1': 'branch2.0.weight',
            'arg3_1': 'branch2.0.bias',
        },
        inputs_to_buffers={'arg4_1': 'L__self___buffer'},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

其他注意事项:

  • 通过 torch.export.Dim() API 和 dynamic_shapes 参数,我们指定了每个输入的第一个维度为动态。查看输入 arg5_1arg6_1,它们的符号形状分别为 (s0, 64) 和 (s0, 128),而不是我们作为示例传递的 (32, 64) 和 (32, 128) 形状的张量。其中 s0 是一个表示该维度可以取一系列值的符号。

  • exported_program.range_constraints 描述了图中每个符号的取值范围。在这种情况下,我们看到s0 的范围是 [2, inf]。由于一些难以在这里解释的技术原因,它们被假定不为 0 或 1。这不是一个错误,并不一定意味着导出的程序在维度为 0 或 1 时无法正常工作。有关此主题的深入讨论,请参阅The 0/1 Specialization Problem

我们还可以指定输入形状之间的更复杂的关系,例如,两个形状可能相差一个维度,一个形状可能是另一个的两倍,或者一个形状的维度是偶数。例如:

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
        # code: return x + y[1:]
        slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807);  arg1_1 = None
        add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1);  arg0_1 = slice_1 = None
        return (add,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}

需要注意的几点:

  • 通过为第一个输入指定{0: dimx},我们看到其形状现在是动态的,即[s0]。接着,通过为第二个输入指定{0: dimy},我们发现它的形状也是动态的。然而,由于表示了dimy = dimx + 1的关系,arg1_1不再使用新的符号,而是与arg0_1中的s0相同。因此,关系dimy = dimx + 1通过s0 + 1来表示。

  • 查看范围约束,我们看到s0的初始范围是[3, 6],并且可以得出s0 + 1的解算范围为[4, 7]。

序列化

要保存ExportedProgram,用户可以使用torch.export.save()torch.export.load() API。通常会使用.pt2文件扩展名来保存ExportedProgram

示例:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

专业领域

理解 torch.export 行为的关键在于区分 静态 值和 动态 值。

一个动态值是指在每次运行时可能会发生变化的值。这些值的行为就像传递给 Python 函数的正常参数一样——你可以为参数传入不同的值,并期望函数能够正确处理。Tensor 数据被视为动态值。

一个静态值是在导出时确定的,并且在导出程序的不同执行间不会发生变化。当在跟踪过程中遇到这个值时,导出器会将其作为常量处理,并直接嵌入到图中。

当执行一个操作(例如 x + y)且所有输入都是静态的,该操作的结果会直接硬编码到图中,并且这个操作不会显示出来(即它会被常量折叠)。

当一个值被硬编码到图中时,我们说该图已经针对这个值进行了专门化

以下值为静态:

输入 tensor 形状

默认情况下,torch.export 会根据输入张量的形状来跟踪程序,除非通过 dynamic_shapes 参数将维度指定为动态。这意味着如果存在依赖于形状的控制流,torch.export 将针对给定样本输入所选择的分支进行专业化处理。例如:

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 2]):
            add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            return (add,)

条件表达式 (x.shape[0] > 5) 不出现在 ExportedProgram 中,因为示例输入具有静态形状 (10, 2)。由于 torch.export 根据输入的静态形状进行特化处理,else 分支 (x - 1) 永远不会被执行。为了在跟踪图中保留基于张量形状的动态分支行为,需要使用 torch.export.Dim() 来指定输入张量 (x.shape[0]) 的维度为动态,并且源代码需要重写

请注意,作为模块状态一部分的张量(如参数和缓冲区)总是具有静态形状。

Python基本数据类型

torch.export 专门处理 Python 中的原始类型,例如 intfloatboolstr。此外,还有动态变体如 SymIntSymFloatSymBool

例如:

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
            add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
            add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
            return (add_2,)

由于整数是专门化的,torch.ops.aten.add.Tensor 操作都使用硬编码的常量 1 进行计算,而不是用户传递的参数 arg1_1。如果在运行时传入不同的值(例如 2),而导出时使用的值是 1,则会导致错误。此外,在 for 循环中使用的 times 迭代器也通过重复的三个 torch.ops.aten.add.Tensor 调用在图中进行了“内联”,并且输入参数 arg2_1 从未被使用。

Python容器

Python 的容器(如 ListDictNamedTuple 等)被认为是具有静态结构的。

torch.export 的局限性

图表中断

由于 torch.export 是一个一次性过程,用于从 PyTorch 程序中捕获计算图,它可能会遇到无法追踪的程序部分,因为几乎不可能支持所有 PyTorch 和 Python 特性的跟踪。在 torch.compile 的情况下,不支持的操作会导致“图中断”,并以默认的 Python 评估方式运行该操作。相比之下,torch.export 要求用户提供额外的信息或重写部分代码使其可追踪。由于跟踪基于 TorchDynamo,在字节码级别进行评估,因此与之前的跟踪框架相比,所需的重写将显著减少。

当遇到图中断时,ExportDB 是一个很好的资源,可以了解哪些程序是受支持或不受支持的,并提供如何重写程序以使其可追踪的方法。

一个解决这个问题的方法是使用非严格导出

数据/形状相关的控制流

在形状未被专门化的情况下,数据依赖的控制流(如if x.shape[0] > 2)可能会遇到图中断问题,因为追踪编译器无法处理组合爆炸数量的路径而不会生成代码。在这种情况下,用户需要使用特殊的控制流操作符重写他们的代码。目前我们支持torch.cond来表达类似if-else的控制流(更多功能即将推出!)。

操作符缺少假函数、元数据和抽象内核

在进行追踪时,所有操作符都需要一个 FakeTensor 内核(也称为元内核或抽象实现),以确定该操作符的输入和输出形状。

请参见torch.library.register_fake()以获取更多详情。

如果您的模型使用了还未实现 FakeTensor 内核的 ATen 操作符,请在遇到这种情况时提交一个问题。

API参考

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[源代码]

export() 接受任意的 Python 可调用对象(如 nn.Module、函数或方法)以及示例输入,并生成一个仅表示张量计算的追踪图,以编译时(AOT)的方式进行处理。该追踪图可以使用不同的输入执行或序列化。(1) 它产生标准化的操作符,这些操作符属于功能性的 ATen 操作集(包括任何用户指定的自定义操作符)。(2) 追踪图消除了所有 Python 控制流和数据结构(某些例外情况除外)。(3) 它记录了一组形状约束,以确保这种规范化和控制流消除对于未来的输入是可靠的。

正确性保证

在追踪过程中,export() 记录了用户程序和底层 PyTorch 操作符内核所作的形状相关假设。输出的 ExportedProgram 只有在这些假设成立时才被认为是有效的。

追踪对输入张量的形状(而非值)作出了假设。这些假设必须在图捕获时进行验证,以便export()成功。具体来说:

  • 输入张量的静态形状假设会自动进行验证,无需额外的努力。

  • 为了指定输入张量的动态形状,你需要使用 Dim() API 构建动态维度,并通过 dynamic_shapes 参数将这些维度与示例输入相关联。

如果任何假设无法验证,将会引发致命错误。当这种情况发生时,错误消息会提供一些修复建议来解决导致假设无法验证的问题。例如,export() 可能会建议对动态维度 dim0_x 的定义进行如下修改:该动态维度出现在与输入 x 相关的形状中,并且之前被定义为 Dim("dim0_x"):

dim = Dim("dim0_x", max=5)

这个例子表明生成的代码要求输入x的第一个维度必须小于或等于5才有效。你可以查看并直接将建议的修复方案复制到你的代码中,而无需修改export()调用中的dynamic_shapes参数。

参数
  • mod (Module) – 我们将追踪该模块的前向方法。

  • args (Tuple[Any, ...]) – 位置参数示例。

  • kwargs (Optional[Dict[str, Any]]) – 可选的示例关键字参数。

  • dynamic_shapes (可选[联合类型[字典[str, Any], 元组[Any], 列表[Any]]]) –

    一个可选参数,其类型应为以下两种之一:1) 从 f 的参数名到它们的动态形状规范的字典;2) 按照原始顺序指定每个输入的动态形状规范的元组。如果你要在关键字参数上指定动态性,则需要按照原始函数签名中定义的顺序传递这些参数。

    张量参数的动态形状可以以以下两种方式之一来指定:(1) 一个从动态维度索引到 Dim() 类型的字典,其中静态维度索引可选包含,并且当它们存在时应映射为 None;或者 (2) 由 Dim() 类型或 None 组成的元组/列表,其中 Dim() 类型表示动态维度,静态维度则用 None 表示。对于字典或包含张量的元组/列表参数,则通过递归地使用嵌套规格的映射或序列来指定。

  • strict (bool) – 当启用时(默认情况下),导出函数将通过 TorchDynamo 追踪程序,确保生成的图的有效性。否则,导出的程序不会验证图中的隐含假设,并可能导致原始模型和导出模型之间的行为差异。这在用户需要绕过 tracer 中的问题或逐步增强模型安全性时很有用。请注意,此选项不会影响最终的 IR 规范,并且无论传递什么值,模型都会以相同的方式进行序列化。警告:此选项是实验性的,请自行承担风险。

返回值

包含追踪调用的 ExportedProgram

返回类型

ExportedProgram

支持的输入/输出类型

接受的输入类型(argskwargs)和输出包括:

  • 基本类型包括 torch.Tensorintfloatboolstr

  • 数据类,但在使用前必须先通过调用register_dataclass()进行注册。

  • dictlisttuplenamedtupleOrderedDict组成的(嵌套)数据结构,包含了以上所有类型的元素。

torch.export.save(ep, f, *, extra_files=None, opset_version=None)[源代码]

警告

该项目正处于积极开发阶段,保存的文件在 PyTorch 新版本中可能会失效。

将一个ExportedProgram保存到类似文件的对象中。之后可以使用Python API torch.export.load进行加载。

参数
  • ep (ExportedProgram) - 需要保存的导出程序。

  • f (Union[str, os.PathLike, io.BytesIO) – 类文件对象(需要实现 write 和 flush 方法)或包含文件名的字符串。

  • extra_files (Optional[Dict[str, Any]]) – 文件名到内容的映射,这些内容将作为 f 的一部分进行存储。

  • opset_version (Optional[Dict[str, int]]) – 操作集名称到其版本的可选字典映射

示例:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[源代码]

警告

该项目正处于积极开发阶段,保存的文件在 PyTorch 新版本中可能会失效。

加载一个之前使用torch.export.save保存的ExportedProgram

参数
  • ep (ExportedProgram) - 需要保存的导出程序。

  • f (Union[str, os.PathLike, io.BytesIO) – 类文件对象(需要实现 write 和 flush 方法)或包含文件名的字符串。

  • extra_files (Optional[Dict[str, Any]]) – 此映射中指定的额外文件将会被加载,其内容会被存储在提供的映射中。

  • expected_opset_version (Optional[Dict[str, int]]) – 操作集名称到预期操作集版本的映射表

返回值

一个ExportedProgram对象

返回类型

ExportedProgram

示例:

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep(torch.randn(5)))
torch.export.register_dataclass(cls, *, serialized_type_name=None)[源代码]

将数据类注册为torch.export.export()的有效输入和输出类型。

参数
  • cls (Type[Any]) – 需要注册的数据类类型

  • serialized_type_name (Optional[str]) – 数据类的序列化名称。

  • this如果你要序列化包含 pytree TreeSpec 的内容,这是必需的) –

  • dataclass -

示例:

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

def fn(o: InputDataClass) -> torch.Tensor:
    res = res=o.feature + o.bias
    return OutputDataClass(res=res)

ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)
torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[源代码]

Dim() 构建了一个类似于具有范围的命名符号整数的类型。它可以用来描述动态张量维度的各种可能值。需要注意的是,同一个张量的不同动态维度,或者不同张量之间的动态维度,可以使用相同的类型来描述。

参数
  • name (str) – 调试时使用的可读名称。

  • min (Optional[int]) – 给定符号的最小可能值(包含在内)

  • max (Optional[int]) – 给定符号的可能最大值(包括该值)

返回值

可用于张量动态形状规格的类型。

torch.export.dims(*names, min=None, max=None)[源代码]

用于创建多种 Dim() 类型的工具。

torch.export.dynamic_shapes.ShapesCollection[源代码]

动态形状的构建器,用于为输入中出现的张量分配动态形状规格。

示例:

args = ({“x”: tensor_x, “others”: [tensor_y, tensor_z]})

dim = torch.export.Dim(…)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# 这等同于以下自动生成的代码:
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)

dynamic_shapes(m, args, kwargs=None)[源代码]

生成动态形状。

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[源代码]

处理导出动态形状的建议修复和/或自动动态形状。根据 ConstraintViolation 错误消息和原始动态形状,精炼给定的动态形状规范。

在大多数情况下,行为很简单:对于那些专门化或细化Dim范围的建议修复,或是提出派生关系的修复,新的动态形状规范将会进行相应的更新。

例如:建议的修复方法:

dim = Dim('dim', min=3, max=6) -> 这只是精炼了 dim 的范围
dim = 4 -> 这将其专门化为常量 4
dy = dx + 1 -> dy 被指定为一个独立的 dim,但实际上与 dx 存在这样的关系

然而,与派生维度相关的建议修复可能会更复杂。例如,如果为根维度提供了建议修复,那么新的派生维度值会根据根维度来评估。

例如:dx = Dim('dx')
dy = dx + 2
dynamic_shapes = {"x": (dx,), "y": (dy,)}

建议的解决方法:

dx = 4 # 特化会使得 dy 也进行特化,即 dy = 6 dx = Dim('dx', max=6) # 此时 dy 的最大值为 8

派生维度的建议修复还可以用来表示可除性约束。这涉及到创建不依赖于特定输入形状的新根维度。在这种情况下,这些根维度不会直接出现在新的规范中,而是会成为其他某个维度的根。

例如:建议的修复方法:

_dx = Dim(' _dx ', max=1024) # 这不会出现在返回结果中,但 dx 会 dx = 4 * _dx # 现在 dx 可以被 4 整除,并且最大值为 4096

返回类型

Union[Dict[str, Any], Tuple[Any, ...], List[Any]]

torch.export.Constraint

别名,等同于 Union[_Constraint, _DerivedConstraint]

classtorch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[源代码]

这是由 export() 函数生成的程序包。它包含了表示张量计算的 torch.fx.Graph,一个包含所有提升参数和缓冲区张量值的状态字典,以及各种元数据。

你可以像调用最初被 export() 追踪的原始可调用对象一样,使用相同的调用约定来调用一个 ExportedProgram。

要对图进行变换,请使用.module属性访问一个torch.fx.GraphModule。然后可以使用FX变换重写图。之后,你可以简单地再次调用export()来构建一个正确的ExportedProgram

模块()[源代码]

返回一个自包含的GraphModule,其中包含了所有内联的参数和缓冲区。

返回类型

Module

缓冲区()[源代码]

返回一个遍历原始模块缓冲区的迭代器。

警告

此 API 是实验性的,且支持向下兼容。

返回类型

Iterator [ Tensor ]

named_buffers()[源代码]

返回一个原始模块缓冲区的迭代器,并同时提供缓冲区的名字和其内容。

警告

此 API 是实验性的,且支持向下兼容。

返回类型

Iterator[Tuple[str, Tensor]]

参数()[源代码]

返回原始模块参数的迭代器。

警告

此 API 是实验性的,且支持向下兼容。

返回类型

Iterator[Parameter]

named_parameters()[源代码]

返回原始模块参数的迭代器,同时提供参数的名称和其本身的值。

警告

此 API 是实验性的,且支持向下兼容。

返回类型

Iterator[Tuple[str, Parameter]]

run_decompositions(decomp_table=None, _preserve_ops=())[源代码]

对导出的程序进行一系列分解,并返回一个新的导出程序。默认情况下,我们会使用核心ATen分解来获取核心ATen操作集中的操作。

目前,我们不会分解联合图。

返回类型

ExportedProgram

torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[源代码]
torch.export.ExportGraphSignature(input_specs, output_specs)[源代码]

ExportGraphSignature 定义了导出图的输入和输出签名,它是一个具有更强不变量保证的 fx.Graph。

导出图是功能性的,并不通过getattr节点访问图内的“状态”,如参数或缓冲区。相反,export()会将参数、缓冲区和常量张量作为输入从图中提取出来。同样,对任何缓冲区的修改也不会包含在图内,而是将其更新值作为导出图的附加输出。

所有输入和输出的顺序如下:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块:
class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

生成的图表是:

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

生成的ExportGraphSignature将是:

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)
torch.export.ModuleCallSignature(inputs:List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], outputs:List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], in_spec:torch.utils._pytree.TreeSpec, out_spec:torch.utils._pytree.TreeSpec)[源代码]
torch.export.ModuleCallEntry(fqn:str, signature:Optional[torch.export.exported_program.ModuleCallSignature]=None)[源代码]
torch.export.graph_signature.InputKind(value)[源代码]

一个枚举类型。

torch.export.graph_signature.InputSpec(kind:torch.export.graph_signature.InputKind, arg:Union[torch.export.graph_signature.TensorArgument,torch.export.graph_signature.SymIntArgument,torch.export.graph_signature.ConstantArgument,torch.export.graph_signature.CustomObjArgument,torch.export.graph_signature.TokenArgument], target:Optional[str], persistent:Optional[bool]=None)[源代码]
torch.export.graph_signature.OutputKind(value)[源代码]

一个枚举类型。

torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[源代码]
torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源代码]

ExportGraphSignature 定义了导出图的输入和输出签名,它是一个具有更强不变量保证的 fx.Graph。

导出图是功能性的,并不通过getattr节点访问图内的“状态”,如参数或缓冲区。相反,export()会将参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不会包含在图内,而是将其更新值作为导出图的附加输出。

所有输入和输出的顺序如下:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块:
class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

生成的图表是:

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

生成的ExportGraphSignature将是:

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)
replace_all_uses(old, new)[源代码]

将签名中的旧名称全部替换为新名称。

get_replace_hook()[源代码]
torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[源代码]
torch.export.unflatten.FlatArgsAdapter[源代码]

将输入参数适配为符合 target_spec,通过使用 input_spec

abstractadapt(target_spec, input_spec, input_args)[源代码]

注意:此适配器可能会更改给定的 input_args_with_path

返回类型

`List[Any]`

torch.export.unflatten.InterpreterModule(graph)[源代码]

这是一个使用 torch.fx.Interpreter 进行执行的模块,而不用通常 GraphModule 中使用的代码生成。这种方式提供更详细的堆栈跟踪信息,使得调试变得更加简单。

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[源代码]

将导出的程序展开,生成一个具有与原始即时执行模块相同层次结构的新模块。如果你尝试使用 torch.export 与其他期望模块层次结构而非 torch.export 通常生成的扁平图的系统一起工作,这会很有用。

注意

未展平模块的参数(args/kwargs)可能与 eager 模块不匹配,因此直接替换模块(例如 self.submod = new_mod)可能会失败。如果你需要替换一个模块,请设置 torch.export.export()preserve_module_call_signature 参数。

参数
  • module (ExportedProgram) - 要展开的 ExportedProgram。

  • flat_args_adapter (Optional[FlatArgsAdapter]) – 当输入的 TreeSpec 与导出模块不匹配时,用于适配扁平化参数。

返回值

一个 UnflattenedModule 实例,其模块层次结构与导出前的原始 eager 模块相同。

返回类型

UnflattenedModule

torch.export.passes.move_to_device_pass(ep, location)[源代码]

将导出的程序移到指定设备。

参数
  • ep (ExportedProgram) - 需要移动的导出程序。

  • location (Union[torch.device, str, Dict[str, str]]) – 导出程序要移动到的设备。如果是一个字符串,它会被解释为设备名称;如果是字典,则表示从现有设备到目标设备的映射。

返回值

已移动的导出程序。

返回类型

ExportedProgram

本页目录