torch.onnx

概述

开放神经网络交换 (ONNX) 是一种用于表示机器学习模型的开放标准格式。 torch.onnx 模块从原生 PyTorch torch.nn.Module 模型中捕获计算图,并将其转换为ONNX 图形

导出的模型可以被任何支持 ONNX 的运行时使用,例如微软的 ONNX 运行时

你可以使用两种类型的 ONNX 导出 API,如下所示。 这两种类型都可以通过函数 torch.onnx.export() 调用。 下面的示例展示了如何导出一个简单的模型。

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 128, 5)

    def forward(self, x):
        return torch.relu(self.conv1(x))

input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)

model = MyModel()

torch.onnx.export(
    model,                  # model to export
    (input_tensor,),        # inputs of the model,
    "my_model.onnx",        # filename of the ONNX model
    input_names=["input"],  # Rename inputs for the ONNX model
    dynamo=True             # True or False to select the exporter to use
)

接下来的部分将介绍导出器的两个版本。

基于 TorchDynamo 的 ONNX 导出器

基于TorchDynamo的ONNX导出器是PyTorch 2.1及其后续版本中最新的(Beta版)导出器。

TorchDynamo 引擎利用 Python 的帧评估 API,将字节码动态转换为 FX 图。然后对生成的 FX 图进行优化,并最终翻译成 ONNX 图。

这种方法的主要优势在于,它通过字节码分析来捕获FX图,从而保留了模型的动态特性,而不会像传统静态追踪技术那样丢失这些特性。

了解基于 TorchDynamo 的 ONNX 导出器的更多信息

TorchScript基的ONNX导出器

基于TorchScript的ONNX导出器从PyTorch 1.2.0版本开始可用

TorchScript 使用 torch.jit.trace() 追踪模型并捕获静态计算图。

因此,生成的图表存在一些限制:

  • 它不记录任何控制流程,例如 if 语句或循环。

  • 未处理训练评估模式间的差异;

  • 无法真正处理动态输入

为了弥补静态追踪的限制,导出器还支持通过torch.jit.script()进行TorchScript脚本编写。这增加了对数据依赖控制流的支持(例如)。然而,由于TorchScript是Python语言的一个子集,并非所有功能都受支持,比如就地操作。

了解有关基于TorchScript的ONNX导出器的更多信息

贡献与开发

ONNX导出器是一个社区项目,欢迎各位贡献。我们遵循PyTorch的贡献指南,但你可能也对我们的开发维基感兴趣。

本页目录