基于TorchDynamo的ONNX exporter
警告
TorchDynamo 的 ONNX 导出器是一款迅速发展的 beta 技术。
概述
ONNX 导出器利用 TorchDynamo 引擎,通过挂钩到 Python 的帧评估 API,将字节码动态地转换为 FX 图。随后,在将其最终翻译成 ONNX 图之前,会对生成的 FX 图进行优化。
这种方法的主要优势在于,它通过字节码分析来捕获FX图,从而保留了模型的动态特性,而不会像传统静态追踪技术那样丢失这些特性。
导出器被设计为模块化和可扩展,包含以下几个组成部分:
ONNX 导出器:
Exporter
主类,负责协调导出过程。ONNX 导出选项:
ExportOptions
提供了一组用于控制导出过程的选项。ONNX 注册表:
OnnxRegistry
是用于存储 ONNX 操作符和函数的注册表。FX 图提取器:
FXGraphExtractor
用于从 PyTorch 模型中提取 FX 图。模拟模式:
ONNXFakeContext
是一个上下文管理器,用于在大规模模型中启用模拟模式。ONNX程序:
ONNXProgram
是导出器生成的结果,包含了导出的ONNX图以及相关的诊断信息。ONNX 诊断选项:
DiagnosticOptions
提供了一系列选项来控制导出器生成的诊断信息。
依赖项
ONNX 导出器依赖于一些额外的 Python 包:
它们可以使用 pip 进行安装:
pipinstall--upgradeonnxonnxscript
ONNX Runtime 可以用于在多种处理器上运行模型。
一个简单的例子
以下是一个使用导出器API的演示,示例中采用了一个简单的多层感知器(MLP):
import torch import torch.nn as nn class MLPModel(nn.Module): def __init__(self): super().__init__() self.fc0 = nn.Linear(8, 8, bias=True) self.fc1 = nn.Linear(8, 4, bias=True) self.fc2 = nn.Linear(4, 2, bias=True) self.fc3 = nn.Linear(2, 2, bias=True) def forward(self, tensor_x: torch.Tensor): tensor_x = self.fc0(tensor_x) tensor_x = torch.sigmoid(tensor_x) tensor_x = self.fc1(tensor_x) tensor_x = torch.sigmoid(tensor_x) tensor_x = self.fc2(tensor_x) tensor_x = torch.sigmoid(tensor_x) output = self.fc3(tensor_x) return output model = MLPModel() tensor_x = torch.rand((97, 8), dtype=torch.float32) onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True)
如上所示代码表明,你需要向torch.onnx.export()
提供模型实例及其输入。导出器将返回一个包含导出 ONNX 图和额外信息的 torch.onnx.ONNXProgram
实例。
通过 onnx_program.model_proto
可用的内存模型是一个符合 ONNX IR 规范 的 onnx.ModelProto
对象。然后可以使用 torch.onnx.ONNXProgram.save()
API 将 ONNX 模型序列化为 Protobuf 文件。
onnx_program.save("mlp.onnx")
有两个函数可以基于TorchDynamo引擎将模型导出到ONNX。它们在生成ExportedProgram
的方式上有所不同。torch.onnx.dynamo_export()
是在PyTorch 2.1中引入的,而torch.onnx.export()
则在PyTorch 2.5中进行了扩展,以便轻松地从TorchScript切换到TorchDynamo。要调用前者函数,可以将上一个示例中的最后一行替换为以下内容。
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
使用GUI查看ONNX模型
你可以使用 Netron 查看导出的模型。

注意,每一层都用一个带有 f 图标(位于右上角)的矩形框来表示。

展开后可以查看函数体。

函数体是由一系列的ONNX操作符或其他函数组成。
转换失败时
应再次调用函数torch.onnx.export()
并设置参数 report=True
。这将生成一个 Markdown 报告,以帮助用户解决问题。
函数 torch.onnx.dynamo_export()
使用 ‘SARIF’ 格式生成报告。通过采用 静态分析结果互换格式(即 SARIF),ONNX 诊断不仅限于常规日志,并且帮助用户通过图形界面工具(如 Visual Studio Code 的 SARIF Viewer)调试和改进其模型。
主要的优点如下:
诊断信息以机器可解析的 静态分析结果互换格式 (SARIF) 形式发出。
一种新颖、清晰且结构化的添加和管理诊断规则的方法。
为将来基于诊断结果进行的改进奠定基础。
API参考
- torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)[源代码]
-
将 torch.nn.Module 导出到 ONNX 图。
- 参数
-
-
model (torch.nn.Module | Callable | torch.export.ExportedProgram) – 需要导出到 ONNX 的 PyTorch 模型。
-
model_args - 传递给
model
的位置参数。 -
model_kwargs – 传入
model
的关键字参数。 -
export_options (ExportOptions | None) – 影响导出到 ONNX 的选项。
-
- 返回值
-
导出的ONNX模型在内存中的表示形式。
- 返回类型
-
ONNXProgram | 任意类型
示例 1 - 简单的导出
class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) def forward(self, x, bias=None): out = self.linear(x) out = out + bias return out model = MyModel() kwargs = {"bias": 3.0} args = (torch.randn(2, 2, 2),) onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save( "my_simple_model.onnx" )
示例 2 - 动态形状导出
# The previous model can be exported with dynamic shapes export_options = torch.onnx.ExportOptions(dynamic_shapes=True) onnx_program = torch.onnx.dynamo_export( model, *args, **kwargs, export_options=export_options ) onnx_program.save("my_dynamic_model.onnx")
- 类torch.onnx.ExportOptions(*, dynamic_shapes=None, fake_context=None, onnx_registry=None, diagnostic_options=None)
-
TorchDynamo ONNX 导出器的相关选项。
- 变量
-
-
dynamic_shapes (bool|None) – 输入和输出张量的形状信息提示。当值为
None
时,导出器确定最兼容的设置;当值为True
时,所有输入形状被视为动态;当值为False
时,所有输入形状被视为静态。 -
diagnostic_options (DiagnosticOptions) – 用于导出器的诊断选项。
-
fake_context (ONNXFakeContext | None) – 用于符号追踪的伪上下文。
-
onnx_registry (OnnxRegistry | None) – 用于将 ATen 操作符注册到 ONNX 函数的 ONNX 注册表。
-
- torch.onnx.enable_fake_mode()
-
在整个上下文中启用假模式。
内部会创建一个
torch._subclasses.fake_tensor.FakeTensorMode
上下文管理器,将用户输入和模型参数转换为FakeTensor
。torch._subclasses.fake_tensor.FakeTensor
是一个torch.Tensor
,可以在不实际通过分配在meta
设备上的张量进行计算的情况下运行 PyTorch 代码。由于设备上没有实际的数据被分配,此 API 允许导出大型模型而无需考虑执行该模型所需的实际内存占用。当导出的模型过大,无法放入内存时,强烈建议启用模拟模式。
- 返回值
-
一个
ONNXFakeContext
对象。
示例:
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> import torch >>> import torch.onnx >>> class MyModel(torch.nn.Module): # Dummy model ... def __init__(self) -> None: ... super().__init__() ... self.linear = torch.nn.Linear(2, 2) ... def forward(self, x): ... out = self.linear(x) ... return out >>> with torch.onnx.enable_fake_mode() as fake_context: ... my_nn_module = MyModel() ... arg1 = torch.randn(2, 2, 2) # positional input 1 >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context) >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True) >>> onnx_program.apply_weights(MyModel().state_dict()) >>> # Saving model WITHOUT initializers >>> onnx_program.save( ... "my_model_without_initializers.onnx", ... include_initializers=False, ... keep_initializers_as_inputs=True, ... ) >>> # Saving model WITH initializers >>> onnx_program.save("my_model_with_initializers.onnx")
警告
此 API 是实验性的,且不支持向下兼容。
- 类torch.onnx.ONNXProgram(model_proto, input_adapter, output_adapter, diagnostic_context, *, fake_context=None, export_exception=None, model_torch=None)
-
一个已导出到ONNX的PyTorch模型的内存中的表示形式。
- 参数
-
-
model_proto (onnx.ModelProto) – 作为
onnx.ModelProto
的导出的 ONNX 模型。 -
input_adapter (io_adapter.InputAdapter) – 将 PyTorch 输入转换为 ONNX 输入的输入适配器。
-
output_adapter (io_adapter.OutputAdapter) – 将 PyTorch 输出转换为 ONNX 输出的适配器。
-
diagnostic_context (diagnostics.DiagnosticContext) – SARIF诊断系统中负责记录错误和元数据的上下文对象。
-
fake_context (ONNXFakeContext | None) – 用于符号追踪的伪上下文。
-
export_exception (Exception | None) – 导出过程中发生异常时的错误信息。
-
- adapt_torch_inputs_to_onnx(*model_args, model_with_state_dict=None, **model_kwargs)[源代码]
-
将 PyTorch 模型的输入转换为导出的 ONNX 模型的输入格式。
由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常不一致。例如,PyTorch 模型允许使用 None,而 ONNX 不支持 None。PyTorch 模型可以有张量的嵌套结构,但 ONNX 只支持扁平化的张量等。
实际的适配步骤与每个单独的导出有关。这取决于PyTorch模型、用于导出的具体model_args和model_kwargs参数集,以及导出选项。
此方法会回放导出时记录的自适应步骤。
- 参数
-
-
model_args - PyTorch 模型的参数。
-
model_with_state_dict (torch.nn.Module | Callable | None) – 需要从中获取额外状态的 PyTorch 模型。如果未指定,则使用导出时所用的模型。当启用
enable_fake_mode()
以提取 ONNX 图所需的实际初始化器时,此参数是必需的。 -
model_kwargs – PyTorch 模型的关键词参数。
-
- 返回值
-
从PyTorch模型输入转换而来的张量序列。
- 返回类型
-
Sequence[torch.Tensor 或 int 或 float 或 bool 或 torch.dtype]
示例:
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> import torch >>> import torch.onnx >>> from typing import Dict, Tuple >>> def func_nested_input( ... x_dict: Dict[str, torch.Tensor], ... y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ... ): ... if "a" in x_dict: ... x = x_dict["a"] ... elif "b" in x_dict: ... x = x_dict["b"] ... else: ... x = torch.randn(3) ... ... y1, (y2, y3) = y_tuple ... ... return x + y1 + y2 + y3 >>> x_dict = {"a": torch.tensor(1.)} >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.))) >>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple) >>> print(x_dict, y_tuple) {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.))) >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input)) (tensor(1.), tensor(2.), tensor(3.), tensor(4.))
警告
此 API 是实验性的,且不支持向下兼容。
- adapt_torch_outputs_to_onnx(model_outputs, model_with_state_dict=None)[源代码]
-
将 PyTorch 模型的输出转换为导出的 ONNX 模型的输出格式。
由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常不一致。例如,PyTorch 模型允许使用 None,而 ONNX 不支持 None。PyTorch 模型可以有张量的嵌套结构,但 ONNX 只支持扁平化的张量等。
实际的适配步骤与每个单独的导出有关。这取决于PyTorch模型、用于导出的具体model_args和model_kwargs参数集,以及导出选项。
此方法会回放导出时记录的自适应步骤。
- 参数
-
-
model_outputs (Any) – 模型的 PyTorch 输出。
-
model_with_state_dict (torch.nn.Module | Callable | None) – 需要从中获取额外状态的 PyTorch 模型。如果未指定,则使用导出时所用的模型。当启用
enable_fake_mode()
以提取 ONNX 图所需的实际初始化器时,此参数是必需的。
-
- 返回值
-
PyTorch 模型的输出以导出的 ONNX 模型格式呈现。
- 返回类型
-
Sequence[torch.Tensor | int | float | bool]
示例:
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> import torch >>> import torch.onnx >>> def func_returning_tuples(x, y, z): ... x = x + y ... y = y + z ... z = x + y ... return (x, (y, z)) >>> x = torch.tensor(1.) >>> y = torch.tensor(2.) >>> z = torch.tensor(3.) >>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z) >>> pt_output = func_returning_tuples(x, y, z) >>> print(pt_output) (tensor(3.), (tensor(5.), tensor(8.))) >>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples)) [tensor(3.), tensor(5.), tensor(8.)]
警告
此 API 是实验性的,且不支持向下兼容。
- apply_weights(state_dict)[源代码]
-
将指定状态字典中的权重应用到ONNX模型上。
:param state_dict: 包含要应用到ONNX模型上的权重的状态字典。
- 属性diagnostic_context:diagnostics.DiagnosticContext
-
与导出相关的诊断信息上下文。
- propertyfake_context:ONNXFakeContext|None
-
与导出相关的虚拟上下文。
- 属性model_proto: onnx.ModelProto
-
导出的 ONNX 模型是一个
onnx.ModelProto
。
- save(destination, *, include_initializers=True, model_state=None)[源代码]
-
将内存中的 ONNX 模型使用指定的
serializer
保存到destination
。- 参数
-
-
destination (str|io.BufferedIOBase) – 保存 ONNX 模型的目标位置。它可以是一个字符串或一个类似文件的对象。当与
model_state
一起使用时,它必须是目标的完整路径字符串。如果 destination 是字符串,则除了将 ONNX 模型保存到文件中之外,模型权重还会存储在与 ONNX 模型相同目录下的单独文件中。例如,对于 destination="/path/model.onnx",初始值将在“/path/”文件夹中与“onnx.model”一起保存。 -
include_initializers (bool) – 是否将初始化器作为外部数据包含在 ONNX 图中。不能与 model_state_dict 同时使用。
-
model_state (dict[str, Any] | str | None) – 包含模型所有权重的 PyTorch 模型的状态字典。它可以是一个字符串(表示检查点路径)或一个包含实际模型状态的字典。支持的文件格式与 torch.load 和 safetensors.safe_open 相同。当使用
enable_fake_mode()
但需要在 ONNX 图中添加真实的初始化器时,此参数是必需的。
-
- save_diagnostics(destination)[源代码]
-
将导出的诊断结果保存为 SARIF 日志,存放在指定的目标路径。
- 参数
-
destination (str) – 保存诊断 SARIF 日志的目标位置,该位置必须具有 .sarif 扩展名。
- 抛出异常
-
ValueError – 如果目标路径没有以.sarif为扩展名。
- 类torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)
-
用于通过ONNX Runtime影响ONNX模型执行的选项。
- 变量
-
-
session_options (Sequence[onnxruntime.SessionOptions] | None) – ONNX 运行时会话选项。
-
execution_providers (Sequence[str | tuple[str, dict[Any, Any]]] | None) – 在模型执行期间使用的 ONNX Runtime 执行提供程序。
解释更清晰的版本:
execution_providers (Sequence[str | tuple[str, dict[Any, Any]]] | None) – 在模型执行过程中指定要使用的 ONNX Runtime 执行提供程序。
-
execution_provider_options (Sequence[dict[Any, Any]] | None) – ONNX Runtime 的执行提供程序选项。
-
- 类torch.onnx.OnnxExporterError
-
由ONNX导出器引发的错误。这是各种导出器错误的基础类。
- 类torch.onnx.OnnxRegistry
-
用于ONNX函数的注册表。
注册表维护了从合格名称到符号函数的映射,这些函数在一个固定的操作集版本中有效。它支持注册自定义的ONNX脚本函数,并使调度程序能够将调用分配给相应的函数。
- get_op_functions(namespace, op_name, overload=None)[源代码]
-
为给定的操作 torch.ops.<namespace>.<op_name>.<overload> 返回一个 ONNXFunctions 列表。
列表按照注册时间排序,自定义操作符应位于列表的后半部分。
- is_registered_op(namespace, op_name, overload=None)[源代码]
-
返回给定操作是否已注册:torch.ops.
. . 。
- 属性opset_version:int
-
导出器应使用的ONNX操作集版本。默认值为当前支持的最新ONNX操作集版本,即18。随着ONNX的发展,默认版本会逐渐更新。
- register_op(function, namespace, op_name, overload=None, is_complex=False)[源代码]
-
注册一个自定义操作符:torch.ops.<命名空间>.<操作符名称>.<重载>。
- 参数
- 抛出异常
-
ValueError – 如果名称不符合“命名空间::操作”的格式。