基于TorchDynamo的ONNX exporter

警告

TorchDynamo 的 ONNX 导出器是一款迅速发展的 beta 技术。

概述

ONNX 导出器利用 TorchDynamo 引擎,通过挂钩到 Python 的帧评估 API,将字节码动态地转换为 FX 图。随后,在将其最终翻译成 ONNX 图之前,会对生成的 FX 图进行优化。

这种方法的主要优势在于,它通过字节码分析来捕获FX图,从而保留了模型的动态特性,而不会像传统静态追踪技术那样丢失这些特性。

导出器被设计为模块化和可扩展,包含以下几个组成部分:

  • ONNX 导出器: Exporter 主类,负责协调导出过程。

  • ONNX 导出选项: ExportOptions 提供了一组用于控制导出过程的选项。

  • ONNX 注册表: OnnxRegistry 是用于存储 ONNX 操作符和函数的注册表。

  • FX 图提取器: FXGraphExtractor 用于从 PyTorch 模型中提取 FX 图。

  • 模拟模式: ONNXFakeContext 是一个上下文管理器,用于在大规模模型中启用模拟模式。

  • ONNX程序: ONNXProgram 是导出器生成的结果,包含了导出的ONNX图以及相关的诊断信息。

  • ONNX 诊断选项: DiagnosticOptions 提供了一系列选项来控制导出器生成的诊断信息。

依赖项

ONNX 导出器依赖于一些额外的 Python 包:

它们可以使用 pip 进行安装:

pipinstall--upgradeonnxonnxscript

ONNX Runtime 可以用于在多种处理器上运行模型。

一个简单的例子

以下是一个使用导出器API的演示,示例中采用了一个简单的多层感知器(MLP):

import torch
import torch.nn as nn

class MLPModel(nn.Module):
  def __init__(self):
      super().__init__()
      self.fc0 = nn.Linear(8, 8, bias=True)
      self.fc1 = nn.Linear(8, 4, bias=True)
      self.fc2 = nn.Linear(4, 2, bias=True)
      self.fc3 = nn.Linear(2, 2, bias=True)

  def forward(self, tensor_x: torch.Tensor):
      tensor_x = self.fc0(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      tensor_x = self.fc1(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      tensor_x = self.fc2(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      output = self.fc3(tensor_x)
      return output

model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True)

如上所示代码表明,你需要向torch.onnx.export() 提供模型实例及其输入。导出器将返回一个包含导出 ONNX 图和额外信息的 torch.onnx.ONNXProgram 实例。

通过 onnx_program.model_proto 可用的内存模型是一个符合 ONNX IR 规范onnx.ModelProto 对象。然后可以使用 torch.onnx.ONNXProgram.save() API 将 ONNX 模型序列化为 Protobuf 文件

onnx_program.save("mlp.onnx")

有两个函数可以基于TorchDynamo引擎将模型导出到ONNX。它们在生成ExportedProgram的方式上有所不同。torch.onnx.dynamo_export()是在PyTorch 2.1中引入的,而torch.onnx.export()则在PyTorch 2.5中进行了扩展,以便轻松地从TorchScript切换到TorchDynamo。要调用前者函数,可以将上一个示例中的最后一行替换为以下内容。

onnx_program = torch.onnx.dynamo_export(model, tensor_x)

使用GUI查看ONNX模型

你可以使用 Netron 查看导出的模型。

MLP model as viewed using Netron

注意,每一层都用一个带有 f 图标(位于右上角)的矩形框来表示。

ONNX function highlighted on MLP model

展开后可以查看函数体。

ONNX function body

函数体是由一系列的ONNX操作符或其他函数组成。

转换失败时

应再次调用函数torch.onnx.export() 并设置参数 report=True。这将生成一个 Markdown 报告,以帮助用户解决问题。

函数 torch.onnx.dynamo_export() 使用 ‘SARIF’ 格式生成报告。通过采用 静态分析结果互换格式(即 SARIF),ONNX 诊断不仅限于常规日志,并且帮助用户通过图形界面工具(如 Visual Studio Code 的 SARIF Viewer)调试和改进其模型。

主要的优点如下:

  • 诊断信息以机器可解析的 静态分析结果互换格式 (SARIF) 形式发出。

  • 一种新颖、清晰且结构化的添加和管理诊断规则的方法。

  • 为将来基于诊断结果进行的改进奠定基础。

API参考

torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)[源代码]

将 torch.nn.Module 导出到 ONNX 图。

参数
  • model (torch.nn.Module | Callable | torch.export.ExportedProgram) – 需要导出到 ONNX 的 PyTorch 模型。

  • model_args - 传递给 model 的位置参数。

  • model_kwargs – 传入 model 的关键字参数。

  • export_options (ExportOptions | None) – 影响导出到 ONNX 的选项。

返回值

导出的ONNX模型在内存中的表示形式。

返回类型

ONNXProgram | 任意类型

示例 1 - 简单的导出

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x, bias=None):
        out = self.linear(x)
        out = out + bias
        return out


model = MyModel()
kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
    "my_simple_model.onnx"
)

示例 2 - 动态形状导出

# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model, *args, **kwargs, export_options=export_options
)
onnx_program.save("my_dynamic_model.onnx")
torch.onnx.ExportOptions(*, dynamic_shapes=None, fake_context=None, onnx_registry=None, diagnostic_options=None)

TorchDynamo ONNX 导出器的相关选项。

变量
  • dynamic_shapes (bool|None) – 输入和输出张量的形状信息提示。当值为 None 时,导出器确定最兼容的设置;当值为 True 时,所有输入形状被视为动态;当值为 False 时,所有输入形状被视为静态。

  • diagnostic_options (DiagnosticOptions) – 用于导出器的诊断选项。

  • fake_context (ONNXFakeContext | None) – 用于符号追踪的伪上下文。

  • onnx_registry (OnnxRegistry | None) – 用于将 ATen 操作符注册到 ONNX 函数的 ONNX 注册表。

torch.onnx.enable_fake_mode()

在整个上下文中启用假模式。

内部会创建一个 torch._subclasses.fake_tensor.FakeTensorMode 上下文管理器,将用户输入和模型参数转换为 FakeTensor

torch._subclasses.fake_tensor.FakeTensor 是一个 torch.Tensor,可以在不实际通过分配在 meta 设备上的张量进行计算的情况下运行 PyTorch 代码。由于设备上没有实际的数据被分配,此 API 允许导出大型模型而无需考虑执行该模型所需的实际内存占用。

当导出的模型过大,无法放入内存时,强烈建议启用模拟模式。

返回值

一个 ONNXFakeContext 对象。

示例:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module):  # Dummy model
...     def __init__(self) -> None:
...         super().__init__()
...         self.linear = torch.nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
>>> with torch.onnx.enable_fake_mode() as fake_context:
...     my_nn_module = MyModel()
...     arg1 = torch.randn(2, 2, 2)  # positional input 1
>>> export_options = torch.onnx.ExportOptions(fake_context=fake_context)
>>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True)
>>> onnx_program.apply_weights(MyModel().state_dict())
>>> # Saving model WITHOUT initializers
>>> onnx_program.save(
...     "my_model_without_initializers.onnx",
...     include_initializers=False,
...     keep_initializers_as_inputs=True,
... )
>>> # Saving model WITH initializers
>>> onnx_program.save("my_model_with_initializers.onnx")

警告

此 API 是实验性的,且支持向下兼容。

torch.onnx.ONNXProgram(model_proto, input_adapter, output_adapter, diagnostic_context, *, fake_context=None, export_exception=None, model_torch=None)

一个已导出到ONNX的PyTorch模型的内存中的表示形式。

参数
  • model_proto (onnx.ModelProto) – 作为 onnx.ModelProto 的导出的 ONNX 模型。

  • input_adapter (io_adapter.InputAdapter) – 将 PyTorch 输入转换为 ONNX 输入的输入适配器。

  • output_adapter (io_adapter.OutputAdapter) – 将 PyTorch 输出转换为 ONNX 输出的适配器。

  • diagnostic_context (diagnostics.DiagnosticContext) – SARIF诊断系统中负责记录错误和元数据的上下文对象。

  • fake_context (ONNXFakeContext | None) – 用于符号追踪的伪上下文。

  • export_exception (Exception | None) – 导出过程中发生异常时的错误信息。

adapt_torch_inputs_to_onnx(*model_args, model_with_state_dict=None, **model_kwargs)[源代码]

将 PyTorch 模型的输入转换为导出的 ONNX 模型的输入格式。

由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常不一致。例如,PyTorch 模型允许使用 None,而 ONNX 不支持 None。PyTorch 模型可以有张量的嵌套结构,但 ONNX 只支持扁平化的张量等。

实际的适配步骤与每个单独的导出有关。这取决于PyTorch模型、用于导出的具体model_args和model_kwargs参数集,以及导出选项。

此方法会回放导出时记录的自适应步骤。

参数
  • model_args - PyTorch 模型的参数。

  • model_with_state_dict (torch.nn.Module | Callable | None) – 需要从中获取额外状态的 PyTorch 模型。如果未指定,则使用导出时所用的模型。当启用 enable_fake_mode() 以提取 ONNX 图所需的实际初始化器时,此参数是必需的。

  • model_kwargs – PyTorch 模型的关键词参数。

返回值

从PyTorch模型输入转换而来的张量序列。

返回类型

Sequence[torch.Tensorintfloatbooltorch.dtype]

示例:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> from typing import Dict, Tuple
>>> def func_nested_input(
...     x_dict: Dict[str, torch.Tensor],
...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
... ):
...     if "a" in x_dict:
...         x = x_dict["a"]
...     elif "b" in x_dict:
...         x = x_dict["b"]
...     else:
...         x = torch.randn(3)
...
...     y1, (y2, y3) = y_tuple
...
...     return x + y1 + y2 + y3
>>> x_dict = {"a": torch.tensor(1.)}
>>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
>>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
>>> print(x_dict, y_tuple)
{'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
>>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
(tensor(1.), tensor(2.), tensor(3.), tensor(4.))

警告

此 API 是实验性的,且支持向下兼容。

adapt_torch_outputs_to_onnx(model_outputs, model_with_state_dict=None)[源代码]

将 PyTorch 模型的输出转换为导出的 ONNX 模型的输出格式。

由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常不一致。例如,PyTorch 模型允许使用 None,而 ONNX 不支持 None。PyTorch 模型可以有张量的嵌套结构,但 ONNX 只支持扁平化的张量等。

实际的适配步骤与每个单独的导出有关。这取决于PyTorch模型、用于导出的具体model_args和model_kwargs参数集,以及导出选项。

此方法会回放导出时记录的自适应步骤。

参数
  • model_outputs (Any) – 模型的 PyTorch 输出。

  • model_with_state_dict (torch.nn.Module | Callable | None) – 需要从中获取额外状态的 PyTorch 模型。如果未指定,则使用导出时所用的模型。当启用 enable_fake_mode() 以提取 ONNX 图所需的实际初始化器时,此参数是必需的。

返回值

PyTorch 模型的输出以导出的 ONNX 模型格式呈现。

返回类型

Sequence[torch.Tensor | int | float | bool]

示例:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> def func_returning_tuples(x, y, z):
...     x = x + y
...     y = y + z
...     z = x + y
...     return (x, (y, z))
>>> x = torch.tensor(1.)
>>> y = torch.tensor(2.)
>>> z = torch.tensor(3.)
>>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
>>> pt_output = func_returning_tuples(x, y, z)
>>> print(pt_output)
(tensor(3.), (tensor(5.), tensor(8.)))
>>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
[tensor(3.), tensor(5.), tensor(8.)]

警告

此 API 是实验性的,且支持向下兼容。

apply_weights(state_dict)[源代码]

将指定状态字典中的权重应用到ONNX模型上。
:param state_dict: 包含要应用到ONNX模型上的权重的状态字典。

属性diagnostic_context:diagnostics.DiagnosticContext

与导出相关的诊断信息上下文。

propertyfake_context:ONNXFakeContext|None

与导出相关的虚拟上下文。

属性model_proto: onnx.ModelProto

导出的 ONNX 模型是一个 onnx.ModelProto

save(destination, *, include_initializers=True, model_state=None)[源代码]

将内存中的 ONNX 模型使用指定的 serializer 保存到 destination

参数
  • destination (str|io.BufferedIOBase) – 保存 ONNX 模型的目标位置。它可以是一个字符串或一个类似文件的对象。当与 model_state 一起使用时,它必须是目标的完整路径字符串。如果 destination 是字符串,则除了将 ONNX 模型保存到文件中之外,模型权重还会存储在与 ONNX 模型相同目录下的单独文件中。例如,对于 destination="/path/model.onnx",初始值将在“/path/”文件夹中与“onnx.model”一起保存。

  • include_initializers (bool) – 是否将初始化器作为外部数据包含在 ONNX 图中。不能与 model_state_dict 同时使用。

  • model_state (dict[str, Any] | str | None) – 包含模型所有权重的 PyTorch 模型的状态字典。它可以是一个字符串(表示检查点路径)或一个包含实际模型状态的字典。支持的文件格式与 torch.loadsafetensors.safe_open 相同。当使用 enable_fake_mode() 但需要在 ONNX 图中添加真实的初始化器时,此参数是必需的。

save_diagnostics(destination)[源代码]

将导出的诊断结果保存为 SARIF 日志,存放在指定的目标路径。

参数

destination (str) – 保存诊断 SARIF 日志的目标位置,该位置必须具有 .sarif 扩展名。

抛出异常

ValueError – 如果目标路径没有以.sarif为扩展名。

torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)

用于通过ONNX Runtime影响ONNX模型执行的选项。

变量
  • session_options (Sequence[onnxruntime.SessionOptions] | None) – ONNX 运行时会话选项。

  • execution_providers (Sequence[str | tuple[str, dict[Any, Any]]] | None) – 在模型执行期间使用的 ONNX Runtime 执行提供程序。

    解释更清晰的版本:

    execution_providers (Sequence[str | tuple[str, dict[Any, Any]]] | None) – 在模型执行过程中指定要使用的 ONNX Runtime 执行提供程序。

  • execution_provider_options (Sequence[dict[Any, Any]] | None) – ONNX Runtime 的执行提供程序选项。

torch.onnx.OnnxExporterError

由ONNX导出器引发的错误。这是各种导出器错误的基础类。

torch.onnx.OnnxRegistry

用于ONNX函数的注册表。

注册表维护了从合格名称到符号函数的映射,这些函数在一个固定的操作集版本中有效。它支持注册自定义的ONNX脚本函数,并使调度程序能够将调用分配给相应的函数。

get_op_functions(namespace, op_name, overload=None)[源代码]

为给定的操作 torch.ops.<namespace>.<op_name>.<overload> 返回一个 ONNXFunctions 列表。

列表按照注册时间排序,自定义操作符应位于列表的后半部分。

参数
  • 命名空间 (str) – 指定要获取的操作符的命名空间。

  • op_name (str) – 指定要获取的操作符的名称。

  • overload (str | None) – 要获取的操作符的重载。如果是默认重载,则保持为 None。

返回值

如果名称在注册表中存在,则返回相应的 ONNXFunctions 列表;否则返回 None。

返回类型

列表 [registration.ONNXFunction] 或者 None

is_registered_op(namespace, op_name, overload=None)[源代码]

返回给定操作是否已注册:torch.ops. . .

参数
  • 命名空间 (str) – 需要检查的操作符的命名空间。

  • op_name (str) – 需要检查的操作符名称。

  • overload (str | None) – 要检查的操作符的重载。如果是默认重载,则保留为 None。

返回值

如果是已注册的操作,则返回真,否则返回假。

返回类型

bool

属性opset_version:int

导出器应使用的ONNX操作集版本。默认值为当前支持的最新ONNX操作集版本,即18。随着ONNX的发展,默认版本会逐渐更新。

register_op(function, namespace, op_name, overload=None, is_complex=False)[源代码]

注册一个自定义操作符:torch.ops.<命名空间>.<操作符名称>.<重载>。

参数
  • function (onnxscript.OnnxFunctiononnxscript.TracedOnnxFunction) – 需要注册的 onnx-script 函数。

  • 命名空间 (str) – 操作符要注册到的命名空间。

  • op_name (str) – 注册的操作符名称。

  • overload (str | None) – 要注册的操作符的重载。如果是默认重载,则保留为 None。

  • is_complex (bool) – 表示该函数是否用于处理复数类型的输入。

抛出异常

ValueError – 如果名称不符合“命名空间::操作”的格式。

torch.onnx.DiagnosticOptions(verbosity_level=20, warnings_as_errors=False)

诊断上下文的相关选项。

变量
  • verbosity_level (int) – 设置每个诊断的日志信息量,相当于 Python 日志模块中的 ‘level’ 参数。

  • warnings_as_errors (bool) – 当设置为 True 时,将警告视为错误。

本页目录