torch.export
警告
此功能目前处于积极开发阶段的原型,未来可能会有重大变化。
概述
torch.export.export()
接受任意 Python 可调用对象(如 torch.nn.Module
、函数或方法),并以 Ahead-of-Time (AOT) 方式生成一个仅表示张量计算的追踪图。此追踪图可以随后使用不同的输入执行,或者进行序列化。
import torch from torch.export import export class Mod(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: a = torch.sin(x) b = torch.cos(y) return a + b example_args = (torch.randn(10, 10), torch.randn(10, 10)) exported_program: torch.export.ExportedProgram = export( Mod(), args=example_args ) print(exported_program)
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]): # code: a = torch.sin(x) sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1); # code: b = torch.cos(y) cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1); # code: return a + b add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos); return (add,) Graph signature: ExportGraphSignature( parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None, ) Range constraints: {}
torch.export
生成一个干净的中间表示(IR),并保持以下不变性。有关 IR 的更多信息,请参见此处规范。
-
正确性:它确保是对原始程序的准确表示,并保留了原始程序的调用约定。
-
标准化:图中不包含Python语义,来自原始程序的子模块被内联,形成了一个完全展平的计算图。
-
图属性:该图是纯函数式的,不包含任何具有副作用的操作(如变异或别名)。它不会修改任何中间值、参数或缓冲区。
-
元数据: 该图包含了在追踪过程中捕获的元数据,例如用户代码中的堆栈跟踪信息。
在底层,torch.export
使用了以下最新技术:
-
TorchDynamo (torch._dynamo) 是一个内部API,它使用CPython的Frame Evaluation API功能来安全地追踪PyTorch图。这提供了极大的改进的图捕获体验,并且为了完全追踪PyTorch代码所需的重写大大减少。
-
AOT Autograd 提供一个功能化的 PyTorch 图,并确保该图被分解或降级为 ATen 操作集。
-
Torch FX (torch.fx) 是图的底层表示,支持灵活的基于 Python 的变换。
现有框架
torch.compile()
也使用与 torch.export
相同的 PT2 堆栈,但有一些差异:
-
JIT 与 AOT:
torch.compile()
是一个 JIT 编译器,不用于在部署之外生成编译后的工件。 -
部分图捕获与完整图捕获: 当
torch.compile()
遇到无法追踪的模型部分时,它会“图中断”并回退到即时 Python 运行环境中执行程序。相比之下,torch.export
旨在获取 PyTorch 模型的完整图表示,因此当遇到不可追踪的部分时会报错。由于torch.export
生成的是一个与任何 Python 特性或运行环境完全分离的完整图,这个图可以在不同的环境和语言中保存、加载并执行。 -
可用性权衡: 由于
torch.compile()
能够在遇到无法追踪的代码时回退到 Python 运行时,因此它更加灵活。而torch.export
则需要用户提供更多信息或重写代码以使其可被追踪。
与torch.fx.symbolic_trace()
相比,torch.export
使用在 Python 字节码级别操作的 TorchDynamo 进行跟踪,这使得它可以追踪任意的 Python 构造,而不受 Python 操作符重载支持范围的限制。此外,torch.export
会详细记录张量元数据,因此基于张量形状等条件判断不会导致跟踪失败。总体而言,torch.export
预期可以适用于更多的用户程序,并生成更低级别的图(在 torch.ops.aten
操作符级别)。需要注意的是,用户仍然可以在使用 torch.export
之前采用torch.fx.symbolic_trace()
作为预处理步骤。
与torch.fx.symbolic_trace()
相比,torch.export
使用 TorchDynamo 在 Python 字节码级别进行跟踪。这使得它可以追踪任意的 Python 构造,而不受 Python 操作符重载支持范围的限制。此外,torch.export
会详细记录张量元数据,因此基于张量形状等条件判断不会导致跟踪失败。总体而言,torch.export
预期可以适用于更多的用户程序,并生成更低级别的图(在 torch.ops.aten
操作符级别)。需要注意的是,在使用 torch.export
之前,用户仍然可以采用torch.fx.symbolic_trace()
作为预处理步骤。
与torch.jit.script()
相比,torch.export
不会捕获Python的控制流或数据结构,但它支持比TorchScript更多的Python语言特性(因为它更容易全面覆盖Python字节码)。生成的图更简单,并且只包含直线控制流(除了显式的控制流操作符)。
与torch.jit.trace()
相比,torch.export
更为可靠:它能够追踪进行整数计算以确定大小的代码,并记录所有必要的条件,确保特定的跟踪对其他输入也有效。
导出 PyTorch 模型
示例
主要的入口是通过 torch.export.export()
,它接受一个可调用对象(如 torch.nn.Module
、函数或方法)和样本输入,并将计算图捕获到一个 torch.export.ExportedProgram
中。例如:
import torch from torch.export import export # Simple module for demonstration class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=16, kernel_size=3, padding=1 ) self.relu = torch.nn.ReLU() self.maxpool = torch.nn.MaxPool2d(kernel_size=3) def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor: a = self.conv(x) a.add_(constant) return self.maxpool(self.relu(a)) example_args = (torch.randn(1, 3, 256, 256),) example_kwargs = {"constant": torch.ones(1, 16, 256, 256)} exported_program: torch.export.ExportedProgram = export( M(), args=example_args, kwargs=example_kwargs ) print(exported_program)
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]): # code: a = self.conv(x) convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default( arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1 ); # code: a.add_(constant) add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1); # code: return self.maxpool(self.relu(a)) relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add); max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default( relu, [3, 3], [3, 3] ); getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0]; return (getitem,) Graph signature: ExportGraphSignature( parameters=['L__self___conv.weight', 'L__self___conv.bias'], buffers=[], user_inputs=['arg2_1', 'arg3_1'], user_outputs=['getitem'], inputs_to_parameters={ 'arg0_1': 'L__self___conv.weight', 'arg1_1': 'L__self___conv.bias', }, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None, ) Range constraints: {}
查看ExportedProgram
,我们可以注意到以下几点:
-
torch.fx.Graph
包含了原始程序的计算图,并记录下原始代码以方便调试。 -
该图仅包含在此处找到的
torch.ops.aten
操作符和自定义操作符,并且是完全功能性的,不包括任何如torch.add_
之类的原地操作符。 -
参数(包括卷积的权重和偏置)被提升为图的输入,因此图中不再包含之前由
torch.fx.symbolic_trace()
生成的get_attr
节点。 -
The
torch.export.ExportGraphSignature
定义了输入和输出的签名,并指定了哪些输入是参数。 -
图中每个节点生成的张量的形状和数据类型会被记录下来。例如,
convolution
节点将产生一个数据类型为torch.float32
且形状为 (1, 16, 256, 256) 的张量。
非严格模式下的导出
在 PyTorch 2.3 中,我们引入了一种新的追踪模式,称为 非严格模式。该模式仍在完善中,因此如果你遇到任何问题,请通过 Github 提交问题,并加上“oncall: export”标签。
在非严格模式下,我们使用Python解释器来跟踪程序的执行。你的代码将完全像在急切模式下一样运行;唯一的区别是所有的Tensor对象都将被替换为ProxyTensors,这些ProxyTensors会记录它们的所有操作到一个图中。
在严格模式下(当前默认设置),我们首先使用TorchDynamo(一个字节码分析引擎)对程序进行跟踪。TorchDynamo不会执行你的Python代码,而是对其进行符号分析,并基于结果构建图。这种分析使torch.export能够提供更强的安全性保证,但并非所有的Python代码都受支持。
一个可能需要使用非严格模式的例子是,当你遇到 TorchDynamo 不支持的功能并且难以解决时,而你又知道 Python 代码并非完全用于计算。例如:
import contextlib import torch class ContextManager(): def __init__(self): self.count = 0 def __enter__(self): self.count += 1 def __exit__(self, exc_type, exc_value, traceback): self.count -= 1 class M(torch.nn.Module): def forward(self, x): with ContextManager(): return x.sin() + x.cos() export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在这个例子中,第一次使用非严格模式(通过strict=False
标志)成功执行了追踪,而第二次使用默认的严格模式则失败了,因为TorchDynamo无法支持上下文管理器。一种解决方法是重写代码(参见torch.export 的限制),但由于上下文管理器不影响模型中的张量计算,我们可以选择接受非严格模式的结果。
表达动态性
默认情况下,torch.export
假设所有输入形状都是静态的并为此导出程序进行专门化。然而,某些维度(如批量维度)可以是动态变化的,并且每次运行时都可能不同。必须使用 torch.export.Dim()
API 创建这些动态维度并通过torch.export.export()
的dynamic_shapes
参数传递它们。例如:
import torch from torch.export import Dim, export class M(torch.nn.Module): def __init__(self): super().__init__() self.branch1 = torch.nn.Sequential( torch.nn.Linear(64, 32), torch.nn.ReLU() ) self.branch2 = torch.nn.Sequential( torch.nn.Linear(128, 64), torch.nn.ReLU() ) self.buffer = torch.ones(32) def forward(self, x1, x2): out1 = self.branch1(x1) out2 = self.branch2(x2) return (out1 + self.buffer, out2) example_args = (torch.randn(32, 64), torch.randn(32, 128)) # Create a dynamic batch size batch = Dim("batch") # Specify that the first dimension of each input is that batch size dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} exported_program: torch.export.ExportedProgram = export( M(), args=example_args, dynamic_shapes=dynamic_shapes ) print(exported_program)
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]): # code: out1 = self.branch1(x1) permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]); addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute); relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm); # code: out2 = self.branch2(x2) permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]); addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1); relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None # code: return (out1 + self.buffer, out2) add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1); return (add, relu_1) Graph signature: ExportGraphSignature( parameters=[ 'branch1.0.weight', 'branch1.0.bias', 'branch2.0.weight', 'branch2.0.bias', ], buffers=['L__self___buffer'], user_inputs=['arg5_1', 'arg6_1'], user_outputs=['add', 'relu_1'], inputs_to_parameters={ 'arg0_1': 'branch1.0.weight', 'arg1_1': 'branch1.0.bias', 'arg2_1': 'branch2.0.weight', 'arg3_1': 'branch2.0.bias', }, inputs_to_buffers={'arg4_1': 'L__self___buffer'}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None, ) Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
其他注意事项:
-
通过
torch.export.Dim()
API 和dynamic_shapes
参数,我们指定了每个输入的第一个维度为动态。查看输入arg5_1
和arg6_1
,它们的符号形状分别为 (s0, 64) 和 (s0, 128),而不是我们作为示例传递的 (32, 64) 和 (32, 128) 形状的张量。其中s0
是一个表示该维度可以取一系列值的符号。 -
exported_program.range_constraints
描述了图中每个符号的取值范围。在这种情况下,我们看到s0
的范围是 [2, inf]。由于一些难以在这里解释的技术原因,它们被假定不为 0 或 1。这不是一个错误,并不一定意味着导出的程序在维度为 0 或 1 时无法正常工作。有关此主题的深入讨论,请参阅The 0/1 Specialization Problem。
我们还可以指定输入形状之间的更复杂的关系,例如,两个形状可能相差一个维度,一个形状可能是另一个的两倍,或者一个形状的维度是偶数。例如:
class M(torch.nn.Module): def forward(self, x, y): return x + y[1:] x, y = torch.randn(5), torch.randn(6) dimx = torch.export.Dim("dimx", min=3, max=6) dimy = dimx + 1 exported_program = torch.export.export( M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), ) print(exported_program)
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"): # code: return x + y[1:] slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None return (add,) Graph signature: ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None) ], output_specs=[ OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)] ) Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
需要注意的几点:
-
通过为第一个输入指定
{0: dimx}
,我们看到其形状现在是动态的,即[s0]
。接着,通过为第二个输入指定{0: dimy}
,我们发现它的形状也是动态的。然而,由于表示了dimy = dimx + 1
的关系,arg1_1
不再使用新的符号,而是与arg0_1
中的s0
相同。因此,关系dimy = dimx + 1
通过s0 + 1
来表示。 -
查看范围约束,我们看到
s0
的初始范围是[3, 6],并且可以得出s0 + 1
的解算范围为[4, 7]。
序列化
要保存ExportedProgram
,用户可以使用torch.export.save()
和torch.export.load()
API。通常会使用.pt2
文件扩展名来保存ExportedProgram
。
示例:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 exported_program = torch.export.export(MyModule(), torch.randn(5)) torch.export.save(exported_program, 'exported_program.pt2') saved_exported_program = torch.export.load('exported_program.pt2')
专业领域
理解 torch.export
行为的关键在于区分 静态 值和 动态 值。
一个动态值是指在每次运行时可能会发生变化的值。这些值的行为就像传递给 Python 函数的正常参数一样——你可以为参数传入不同的值,并期望函数能够正确处理。Tensor 数据被视为动态值。
一个静态值是在导出时确定的,并且在导出程序的不同执行间不会发生变化。当在跟踪过程中遇到这个值时,导出器会将其作为常量处理,并直接嵌入到图中。
当执行一个操作(例如 x + y
)且所有输入都是静态的,该操作的结果会直接硬编码到图中,并且这个操作不会显示出来(即它会被常量折叠)。
当一个值被硬编码到图中时,我们说该图已经针对这个值进行了专门化。
以下值为静态:
输入 tensor 形状
默认情况下,torch.export
会根据输入张量的形状来跟踪程序,除非通过 dynamic_shapes
参数将维度指定为动态。这意味着如果存在依赖于形状的控制流,torch.export
将针对给定样本输入所选择的分支进行专业化处理。例如:
import torch from torch.export import export class Mod(torch.nn.Module): def forward(self, x): if x.shape[0] > 5: return x + 1 else: return x - 1 example_inputs = (torch.rand(10, 2),) exported_program = export(Mod(), example_inputs) print(exported_program)
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[10, 2]): add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1); return (add,)
条件表达式 (x.shape[0] > 5
) 不出现在 ExportedProgram
中,因为示例输入具有静态形状 (10, 2)。由于 torch.export
根据输入的静态形状进行特化处理,else 分支 (x - 1
) 永远不会被执行。为了在跟踪图中保留基于张量形状的动态分支行为,需要使用 torch.export.Dim()
来指定输入张量 (x.shape[0]
) 的维度为动态,并且源代码需要重写。
请注意,作为模块状态一部分的张量(如参数和缓冲区)总是具有静态形状。
Python基本数据类型
torch.export
专门处理 Python 中的原始类型,例如 int
、float
、bool
和 str
。此外,还有动态变体如 SymInt
、SymFloat
和 SymBool
。
例如:
import torch from torch.export import export class Mod(torch.nn.Module): def forward(self, x: torch.Tensor, const: int, times: int): for i in range(times): x = x + const return x example_inputs = (torch.rand(2, 2), 1, 3) exported_program = export(Mod(), example_inputs) print(exported_program)
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1): add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1); add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1); add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1); return (add_2,)
由于整数是专门化的,torch.ops.aten.add.Tensor
操作都使用硬编码的常量 1
进行计算,而不是用户传递的参数 arg1_1
。如果在运行时传入不同的值(例如 2),而导出时使用的值是 1,则会导致错误。此外,在 for
循环中使用的 times
迭代器也通过重复的三个 torch.ops.aten.add.Tensor
调用在图中进行了“内联”,并且输入参数 arg2_1
从未被使用。
Python容器
Python 的容器(如 List
、Dict
、NamedTuple
等)被认为是具有静态结构的。
torch.export 的局限性
图表中断
由于 torch.export
是一个一次性过程,用于从 PyTorch 程序中捕获计算图,它可能会遇到无法追踪的程序部分,因为几乎不可能支持所有 PyTorch 和 Python 特性的跟踪。在 torch.compile
的情况下,不支持的操作会导致“图中断”,并以默认的 Python 评估方式运行该操作。相比之下,torch.export
要求用户提供额外的信息或重写部分代码使其可追踪。由于跟踪基于 TorchDynamo,在字节码级别进行评估,因此与之前的跟踪框架相比,所需的重写将显著减少。
当遇到图中断时,ExportDB 是一个很好的资源,可以了解哪些程序是受支持或不受支持的,并提供如何重写程序以使其可追踪的方法。
一个解决这个问题的方法是使用非严格导出
数据/形状相关的控制流
在形状未被专门化的情况下,数据依赖的控制流(如if x.shape[0] > 2
)可能会遇到图中断问题,因为追踪编译器无法处理组合爆炸数量的路径而不会生成代码。在这种情况下,用户需要使用特殊的控制流操作符重写他们的代码。目前我们支持torch.cond来表达类似if-else的控制流(更多功能即将推出!)。
操作符缺少假函数、元数据和抽象内核
在进行追踪时,所有操作符都需要一个 FakeTensor 内核(也称为元内核或抽象实现),以确定该操作符的输入和输出形状。
请参见torch.library.register_fake()
以获取更多详情。
如果您的模型使用了还未实现 FakeTensor 内核的 ATen 操作符,请在遇到这种情况时提交一个问题。
了解更多
导出用户的相关链接
PyTorch开发者的深入探索
API参考
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[源代码]
-
export()
接受任意的 Python 可调用对象(如 nn.Module、函数或方法)以及示例输入,并生成一个仅表示张量计算的追踪图,以编译时(AOT)的方式进行处理。该追踪图可以使用不同的输入执行或序列化。(1) 它产生标准化的操作符,这些操作符属于功能性的 ATen 操作集(包括任何用户指定的自定义操作符)。(2) 追踪图消除了所有 Python 控制流和数据结构(某些例外情况除外)。(3) 它记录了一组形状约束,以确保这种规范化和控制流消除对于未来的输入是可靠的。正确性保证
在追踪过程中,
export()
记录了用户程序和底层 PyTorch 操作符内核所作的形状相关假设。输出的ExportedProgram
只有在这些假设成立时才被认为是有效的。追踪对输入张量的形状(而非值)作出了假设。这些假设必须在图捕获时进行验证,以便
export()
成功。具体来说:- 输入张量的静态形状假设会自动进行验证,无需额外的努力。
-
为了指定输入张量的动态形状,你需要使用
Dim()
API 构建动态维度,并通过dynamic_shapes
参数将这些维度与示例输入相关联。
如果任何假设无法验证,将会引发致命错误。当这种情况发生时,错误消息会提供一些修复建议来解决导致假设无法验证的问题。例如,
export()
可能会建议对动态维度dim0_x
的定义进行如下修改:该动态维度出现在与输入x
相关的形状中,并且之前被定义为Dim("dim0_x")
:dim = Dim("dim0_x", max=5)
这个例子表明生成的代码要求输入
x
的第一个维度必须小于或等于5才有效。你可以查看并直接将建议的修复方案复制到你的代码中,而无需修改export()
调用中的dynamic_shapes
参数。- 参数
-
-
mod (Module) – 我们将追踪该模块的前向方法。
-
dynamic_shapes (可选[联合类型[字典[str, Any], 元组[Any], 列表[Any]]]) –
一个可选参数,其类型应为以下两种之一:1) 从
f
的参数名到它们的动态形状规范的字典;2) 按照原始顺序指定每个输入的动态形状规范的元组。如果你要在关键字参数上指定动态性,则需要按照原始函数签名中定义的顺序传递这些参数。张量参数的动态形状可以以以下两种方式之一来指定:(1) 一个从动态维度索引到
Dim()
类型的字典,其中静态维度索引可选包含,并且当它们存在时应映射为 None;或者 (2) 由Dim()
类型或 None 组成的元组/列表,其中Dim()
类型表示动态维度,静态维度则用 None 表示。对于字典或包含张量的元组/列表参数,则通过递归地使用嵌套规格的映射或序列来指定。 -
strict (bool) – 当启用时(默认情况下),导出函数将通过 TorchDynamo 追踪程序,确保生成的图的有效性。否则,导出的程序不会验证图中的隐含假设,并可能导致原始模型和导出模型之间的行为差异。这在用户需要绕过 tracer 中的问题或逐步增强模型安全性时很有用。请注意,此选项不会影响最终的 IR 规范,并且无论传递什么值,模型都会以相同的方式进行序列化。警告:此选项是实验性的,请自行承担风险。
-
- 返回值
-
包含追踪调用的
ExportedProgram
。 - 返回类型
支持的输入/输出类型
接受的输入类型(
args
和kwargs
)和输出包括:-
基本类型包括
torch.Tensor
、int
、float
、bool
和str
。 -
数据类,但在使用前必须先通过调用
register_dataclass()
进行注册。 -
由
dict
、list
、tuple
、namedtuple
和OrderedDict
组成的(嵌套)数据结构,包含了以上所有类型的元素。
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)[源代码]
-
警告
该项目正处于积极开发阶段,保存的文件在 PyTorch 新版本中可能会失效。
将一个
ExportedProgram
保存到类似文件的对象中。之后可以使用Python APItorch.export.load
进行加载。- 参数
-
-
ep (ExportedProgram) - 需要保存的导出程序。
-
f (Union[str, os.PathLike, io.BytesIO) – 类文件对象(需要实现 write 和 flush 方法)或包含文件名的字符串。
-
extra_files (Optional[Dict[str, Any]]) – 文件名到内容的映射,这些内容将作为 f 的一部分进行存储。
-
opset_version (Optional[Dict[str, int]]) – 操作集名称到其版本的可选字典映射
-
示例:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[源代码]
-
警告
该项目正处于积极开发阶段,保存的文件在 PyTorch 新版本中可能会失效。
加载一个之前使用
torch.export.save
保存的ExportedProgram
。- 参数
-
-
ep (ExportedProgram) - 需要保存的导出程序。
-
f (Union[str, os.PathLike, io.BytesIO) – 类文件对象(需要实现 write 和 flush 方法)或包含文件名的字符串。
-
extra_files (Optional[Dict[str, Any]]) – 此映射中指定的额外文件将会被加载,其内容会被存储在提供的映射中。
-
expected_opset_version (Optional[Dict[str, int]]) – 操作集名称到预期操作集版本的映射表
-
- 返回值
-
一个
ExportedProgram
对象 - 返回类型
示例:
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[源代码]
-
将数据类注册为
torch.export.export()
的有效输入和输出类型。- 参数
示例:
@dataclass class InputDataClass: feature: torch.Tensor bias: int class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) def fn(o: InputDataClass) -> torch.Tensor: res = res=o.feature + o.bias return OutputDataClass(res=res) ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[源代码]
-
Dim()
构建了一个类似于具有范围的命名符号整数的类型。它可以用来描述动态张量维度的各种可能值。需要注意的是,同一个张量的不同动态维度,或者不同张量之间的动态维度,可以使用相同的类型来描述。
- torch.export.dims(*names, min=None, max=None)[源代码]
-
用于创建多种
Dim()
类型的工具。
- 类torch.export.dynamic_shapes.ShapesCollection[源代码]
-
动态形状的构建器,用于为输入中出现的张量分配动态形状规格。
- 示例:
-
args = ({“x”: tensor_x, “others”: [tensor_y, tensor_z]})
dim = torch.export.Dim(…)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# 这等同于以下自动生成的代码:
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}torch.export(..., args, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)[源代码]
-
生成动态形状。
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[源代码]
-
处理导出动态形状的建议修复和/或自动动态形状。根据 ConstraintViolation 错误消息和原始动态形状,精炼给定的动态形状规范。
在大多数情况下,行为很简单:对于那些专门化或细化Dim范围的建议修复,或是提出派生关系的修复,新的动态形状规范将会进行相应的更新。
例如:建议的修复方法:
dim = Dim('dim', min=3, max=6) -> 这只是精炼了 dim 的范围
dim = 4 -> 这将其专门化为常量 4
dy = dx + 1 -> dy 被指定为一个独立的 dim,但实际上与 dx 存在这样的关系然而,与派生维度相关的建议修复可能会更复杂。例如,如果为根维度提供了建议修复,那么新的派生维度值会根据根维度来评估。
例如:dx = Dim('dx')
dy = dx + 2
dynamic_shapes = {"x": (dx,), "y": (dy,)}建议的解决方法:
dx = 4 # 特化会使得 dy 也进行特化,即 dy = 6 dx = Dim('dx', max=6) # 此时 dy 的最大值为 8
派生维度的建议修复还可以用来表示可除性约束。这涉及到创建不依赖于特定输入形状的新根维度。在这种情况下,这些根维度不会直接出现在新的规范中,而是会成为其他某个维度的根。
例如:建议的修复方法:
_dx = Dim(' _dx ', max=1024) # 这不会出现在返回结果中,但 dx 会 dx = 4 * _dx # 现在 dx 可以被 4 整除,并且最大值为 4096
- torch.export.Constraint
-
别名,等同于
Union
[_Constraint
,_DerivedConstraint
]
- classtorch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[源代码]
-
这是由
export()
函数生成的程序包。它包含了表示张量计算的torch.fx.Graph
,一个包含所有提升参数和缓冲区张量值的状态字典,以及各种元数据。你可以像调用最初被
export()
追踪的原始可调用对象一样,使用相同的调用约定来调用一个 ExportedProgram。要对图进行变换,请使用
.module
属性访问一个torch.fx.GraphModule
。然后可以使用FX变换重写图。之后,你可以简单地再次调用export()
来构建一个正确的ExportedProgram
。- named_buffers()[源代码]
-
返回一个原始模块缓冲区的迭代器,并同时提供缓冲区的名字和其内容。
警告
此 API 是实验性的,且不支持向下兼容。
- 参数()[源代码]
-
返回原始模块参数的迭代器。
警告
此 API 是实验性的,且不支持向下兼容。
- 返回类型
-
Iterator[Parameter]
- named_parameters()[源代码]
-
返回原始模块参数的迭代器,同时提供参数的名称和其本身的值。
警告
此 API 是实验性的,且不支持向下兼容。
- 类torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[源代码]
-
- 类torch.export.ExportGraphSignature(input_specs, output_specs)[源代码]
-
ExportGraphSignature
定义了导出图的输入和输出签名,它是一个具有更强不变量保证的 fx.Graph。导出图是功能性的,并不通过
getattr
节点访问图内的“状态”,如参数或缓冲区。相反,export()
会将参数、缓冲区和常量张量作为输入从图中提取出来。同样,对任何缓冲区的修改也不会包含在图内,而是将其更新值作为导出图的附加输出。所有输入和输出的顺序如下:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块:class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图表是:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的ExportGraphSignature将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- 类torch.export.ModuleCallSignature(inputs:List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], outputs:List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], in_spec:torch.utils._pytree.TreeSpec, out_spec:torch.utils._pytree.TreeSpec)[源代码]
-
- 类torch.export.ModuleCallEntry(fqn:str, signature:Optional[torch.export.exported_program.ModuleCallSignature]=None)[源代码]
-
- 类torch.export.graph_signature.InputKind(value)[源代码]
-
一个枚举类型。
- 类torch.export.graph_signature.InputSpec(kind:torch.export.graph_signature.InputKind, arg:Union[torch.export.graph_signature.TensorArgument,torch.export.graph_signature.SymIntArgument,torch.export.graph_signature.ConstantArgument,torch.export.graph_signature.CustomObjArgument,torch.export.graph_signature.TokenArgument], target:Optional[str], persistent:Optional[bool]=None)[源代码]
-
- 类torch.export.graph_signature.OutputKind(value)[源代码]
-
一个枚举类型。
- 类torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[源代码]
-
- 类torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源代码]
-
ExportGraphSignature
定义了导出图的输入和输出签名,它是一个具有更强不变量保证的 fx.Graph。导出图是功能性的,并不通过
getattr
节点访问图内的“状态”,如参数或缓冲区。相反,export()
会将参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何修改也不会包含在图内,而是将其更新值作为导出图的附加输出。所有输入和输出的顺序如下:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块:class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图表是:
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的ExportGraphSignature将是:
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- replace_all_uses(old, new)[源代码]
-
将签名中的旧名称全部替换为新名称。
- get_replace_hook()[源代码]
- 类torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[源代码]
-
- 类torch.export.unflatten.FlatArgsAdapter[源代码]
-
将输入参数适配为符合
target_spec
,通过使用input_spec
。- abstractadapt(target_spec, input_spec, input_args)[源代码]
-
注意:此适配器可能会更改给定的
input_args_with_path
。- 返回类型
-
`List[Any]`
- 类torch.export.unflatten.InterpreterModule(graph)[源代码]
-
这是一个使用 torch.fx.Interpreter 进行执行的模块,而不用通常 GraphModule 中使用的代码生成。这种方式提供更详细的堆栈跟踪信息,使得调试变得更加简单。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[源代码]
-
将导出的程序展开,生成一个具有与原始即时执行模块相同层次结构的新模块。如果你尝试使用
torch.export
与其他期望模块层次结构而非torch.export
通常生成的扁平图的系统一起工作,这会很有用。注意
未展平模块的参数(args/kwargs)可能与 eager 模块不匹配,因此直接替换模块(例如
self.submod = new_mod
)可能会失败。如果你需要替换一个模块,请设置torch.export.export()
的preserve_module_call_signature
参数。- 参数
-
-
module (ExportedProgram) - 要展开的 ExportedProgram。
-
flat_args_adapter (Optional[FlatArgsAdapter]) – 当输入的 TreeSpec 与导出模块不匹配时,用于适配扁平化参数。
-
- 返回值
-
一个
UnflattenedModule
实例,其模块层次结构与导出前的原始 eager 模块相同。 - 返回类型
-
UnflattenedModule
- torch.export.passes.move_to_device_pass(ep, location)[源代码]
-
将导出的程序移到指定设备。
- 参数
-
-
ep (ExportedProgram) - 需要移动的导出程序。
-
location (Union[torch.device, str, Dict[str, str]]) – 导出程序要移动到的设备。如果是一个字符串,它会被解释为设备名称;如果是字典,则表示从现有设备到目标设备的映射。
-
- 返回值
-
已移动的导出程序。
- 返回类型