TorchScript基的ONNX导出器
注意
要使用 TorchDynamo 而不是 TorchScript 导出 ONNX 模型,请参阅 torch.onnx.dynamo_export()
。
示例:将AlexNet从PyTorch转换为ONNX
这里是一个简单的脚本,用于将预训练的AlexNet导出为名为alexnet.onnx
的ONNX文件。调用torch.onnx.export
会运行模型一次以追踪其执行过程,并将追踪后的模型导出到指定文件:
import torch import torchvision dummy_input = torch.randn(10, 3, 224, 224, device="cuda") model = torchvision.models.alexnet(pretrained=True).cuda() # Providing input and output names sets the display names for values # within the model's graph. Setting these does not change the semantics # of the graph; it is only for readability. # # The inputs to the network consist of the flat list of inputs (i.e. # the values you would pass to the forward() method) followed by the # flat list of parameters. You can partially specify names, i.e. provide # a list here shorter than the number of inputs to the model, and we will # only set that subset of names, starting from the beginning. input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ] output_names = [ "output1" ] torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
生成的 alexnet.onnx
文件包含一个二进制 协议缓冲区,其中包含了你导出的AlexNet模型的网络结构和参数。设置参数 verbose=True
可以让导出器打印出该模型的人类可读表示:
# These are the inputs and parameters to the network, which have taken on # the names we specified earlier. graph(%actual_input_1 : Float(10, 3, 224, 224) %learned_0 : Float(64, 3, 11, 11) %learned_1 : Float(64) %learned_2 : Float(192, 64, 5, 5) %learned_3 : Float(192) # ---- omitted for brevity ---- %learned_14 : Float(1000, 4096) %learned_15 : Float(1000)) { # Every statement consists of some output tensors (and their types), # the operator to be run (with its attributes, e.g., kernels, strides, # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1) %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0] %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1] %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2] # ---- omitted for brevity ---- %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12] # Dynamic means that the shape is not known. This may be because of a # limitation of our implementation (which we would like to fix in a # future release) or shapes which are truly dynamic. %30 : Dynamic = onnx::Shape(%29), scope: AlexNet %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet # ---- omitted for brevity ---- %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6] return (%output1); }
你也可以使用ONNX库并通过pip
进行安装来验证输出。
pip install onnx
然后,你可以运行:
import onnx # Load the ONNX model model = onnx.load("alexnet.onnx") # Check that the model is well formed onnx.checker.check_model(model) # Print a human readable representation of the graph print(onnx.helper.printable_graph(model.graph))
你也可以使用支持 ONNX 的多种 运行时 来运行导出的模型。例如,在安装了ONNX Runtime之后,你可以加载并运行该模型:
import onnxruntime as ort import numpy as np ort_session = ort.InferenceSession("alexnet.onnx") outputs = ort_session.run( None, {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)}, ) print(outputs[0])
这里有一个更详细的教程,介绍如何将模型导出并使用 ONNX Runtime 运行。
追踪 vs 脚本编写
内部,torch.onnx.export()
需要一个torch.jit.ScriptModule
而不是一个torch.nn.Module
。如果传入的模型还不是ScriptModule
,export()
将使用追踪(tracing)来将其转换为一个:
-
追踪: 如果调用
torch.onnx.export()
传入的模块不是已存在的ScriptModule
, 它会先执行与torch.jit.trace()
等效的操作,即使用给定的args
运行模型一次,并记录在此过程中发生的所有操作。这意味着如果你的模型是动态的(例如,行为取决于输入数据),导出的模型将不会捕获这种动态行为。我们建议检查导出的模型并确保操作符看起来合理。追踪会展开循环和 if 语句,导出一个静态图,该图与追踪运行时完全相同。如果你想以具有动态控制流的方式导出你的模型,则需要使用 脚本化。 -
脚本编写:通过脚本编译模型可以保留动态控制流程,并适用于不同大小的输入。要进行脚本编写:
-
使用
torch.jit.script()
生成一个ScriptModule
。 -
调用
torch.onnx.export()
函数,并将ScriptModule
作为模型。虽然仍然需要提供args
参数,但它们仅用于内部生成示例输出,以便捕获输出的类型和形状信息。不会执行追踪。
-
参见TorchScript入门和TorchScript,了解更多详细信息,包括如何结合追踪和脚本编写来满足不同模型的特定需求。
避免坑陷
避免使用NumPy和内置的Python类型
PyTorch 模型可以使用 NumPy 或 Python 类型和函数编写。但在记录期间,任何为 NumPy 或 Python 类型(而不是 torch.Tensor)的变量会被转换成常量。如果这些值应根据输入变化,则会产生错误的结果。
例如,不要在numpy数组上使用numpy函数:
# Bad! Will be replaced with constants during tracing. x, y = np.random.rand(1, 2), np.random.rand(1, 2) np.concatenate((x, y), axis=1)
在torch.Tensor上使用torch运算符:
# Good! Tensor operations will be captured during tracing. x, y = torch.randn(1, 2), torch.randn(1, 2) torch.cat((x, y), dim=1)
不要使用torch.Tensor.item()
(它会将 Tensor 转换为 Python 内置的数值类型):
# Bad! y.item() will be replaced with a constant during tracing. def forward(self, x, y): return x.reshape(y.item(), -1)
利用torch对单元素张量的隐式转换功能:
# Good! y will be preserved as a variable during tracing. def forward(self, x, y): return x.reshape(y, -1)
避免使用 Tensor.data
使用 Tensor.data 字段可能会生成不正确的追踪信息,导致生成的 ONNX 图也不正确。建议改用 torch.Tensor.detach()
。(正在努力完全移除 Tensor.data)。
在追踪模式下使用 tensor.shape 时避免原地操作
在追踪模式下,从 tensor.shape
获取的形状会被作为张量进行追踪,并且共享同一内存空间。这可能导致最终输出值不匹配。为了避免这种情况,在这些场景中不要使用原地操作(inplace operations)。例如,在模型中:
class Model(torch.nn.Module): def forward(self, states): batch_size, seq_length = states.shape[:2] real_seq_length = seq_length real_seq_length += 2 return real_seq_length + seq_length
real_seq_length
和 seq_length
在追踪模式下使用同一块内存。可以通过修改原地操作来避免这种情况:
real_seq_length = real_seq_length + 2
限制
类型
-
只有
torch.Tensors
、可以轻松转换为torch.Tensors
的数值类型(例如 float、int)以及这些类型的元组和列表才能作为模型的输入或输出。在记录模式下,接受 Dict 和 str 类型的输入和输出,但:-
任何依赖字典或字符串输入值的计算将会被替换为在单次追踪执行过程中观察到的常量值。
-
任何输出为字典的对象将被其值的扁平化序列(移除键)悄悄替换。例如,
{"foo": 1, "bar": 2}
变为(1, 2)
。 -
任何输出为字符串的内容都将被默默地移除。
-
-
由于 ONNX 对嵌套序列的支持有限,在脚本模式中,某些涉及元组和列表的操作不被支持。特别是将元组追加到列表的操作不受支持。在追踪模式下,嵌套序列会在追踪过程中自动展平。
运算符实现的差异
不支持的张量索引模式
以下是无法导出的张量索引模式。如果遇到问题且导出的模型中没有包含以下任何不受支持的模式,请确保你使用的是最新版本的 opset_version
。
读取 / 获取
在读取张量时,以下索引模式不受支持:
# Tensor indices that includes negative values. data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])] # Workarounds: use positive index values.
写入 / 设置
在为Tensor写入索引时,以下模式不被支持:
# Multiple tensor indices if any has rank >= 2 data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data # Workarounds: use single tensor index with rank >= 2, # or multiple consecutive tensor indices with rank == 1. # Multiple tensor indices that are not consecutive data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data # Workarounds: transpose `data` such that tensor indices are consecutive. # Tensor indices that includes negative values. data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data # Workarounds: use positive index values. # Implicit broadcasting required for new_data. data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data # Workarounds: expand new_data explicitly. # Example: # data shape: [3, 4, 5] # new_data shape: [5] # expected new_data shape after broadcasting: [2, 2, 2, 5]
支持操作符的添加
RuntimeError: ONNX export failed: Couldn't export operator foo
当这种情况发生时,你可以尝试以下几种做法:
-
修改模型,使其不再使用该操作符。
-
创建一个符号函数来转换运算符,并将其注册为自定义符号函数。
-
为 PyTorch 贡献代码,将相同的符号函数添加到
torch.onnx
本身。
如果你决定实现一个符号函数(希望你能将其贡献给PyTorch!),可以参考以下方法开始:
ONNX 导出器内部实现
“符号函数”是将一个PyTorch运算符分解成多个ONNX运算符组合的函数。
在导出过程中,导出器按照拓扑顺序访问 TorchScript 图中的每个节点(每个节点包含一个 PyTorch 操作符)。当访问到某个节点时,导出器会查找该操作符的已注册符号函数。符号函数是用 Python 实现的。例如,对于名为 foo
的操作符,其符号函数可能如下所示:
def foo( g, input_0: torch._C.Value, input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]: """ Adds the ONNX operations representing this PyTorch function by updating the graph g with `g.op()` calls. Args: g (Graph): graph to write the ONNX representation into. input_0 (Value): value representing the variables which contain the first input for this operator. input_1 (Value): value representing the variables which contain the second input for this operator. Returns: A Value or List of Values specifying the ONNX nodes that compute something equivalent to the original PyTorch operator with the given inputs. None if it cannot be converted to ONNX. """ ...
torch._C
类型是围绕 ir.h 中定义的 C++ 类型创建的 Python 封装。
添加符号函数的过程根据操作符的类型而不同。
ATen操作符
ATen 是 PyTorch 内置的张量库。如果操作符是 ATen 操作符(在 TorchScript 图中带有前缀 aten::
),请确保它尚未被支持。
支持的运算符清单
访问自动生成的支持的 TorchScript 操作符列表,以了解每个 opset_version
中支持的操作符。
支持添加 aten 或量化运算符
如果操作符不在上述列表中:
-
在
torch/onnx/symbolic_opset<version>.py
中定义符号函数,例如torch/onnx/symbolic_opset9.py。确保该函数与ATen函数名称相同,后者可能在torch/_C/_VariableFunctions.pyi
或torch/nn/functional.pyi
中声明(这些文件是在构建时生成的,在你构建PyTorch之前不会出现在你的代码库中)。 -
默认情况下,第一个参数是 ONNX 图。其他参数名必须与
.pyi
文件中的名称完全一致,因为调度是通过关键字参数来实现的。 -
在符号函数中,如果操作符属于ONNX 标准操作符集,我们只需要创建一个节点来表示该 ONNX 操作符。如果不在此标准集中,我们可以组合几个具有相同语义的标准操作符来等同于 ATen 操作符。
这是一个关于如何处理ELU
操作符缺失符号函数的示例。
如果运行以下代码:
print( torch.jit.trace( torch.nn.ELU(), # module torch.ones(1) # example input ).graph )
我们看到了类似这样的内容:
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU, %input : Float(1, strides=[1], requires_grad=0, device=cpu)): %4 : float = prim::Constant[value=1.]() %5 : int = prim::Constant[value=1]() %6 : int = prim::Constant[value=1]() %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6) return (%7)
因为我们在这张图中看到了aten::elu
,所以知道这是一项ATen操作。
我们查阅了ONNX操作符列表,确认Elu
已经在ONNX中进行了标准化。
我们在torch/nn/functional.pyi
中找到了elu
的签名:
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
我们在 symbolic_opset9.py
中添加了以下代码行:
def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False): return g.op("Elu", input, alpha_f=alpha)
现在,PyTorch可以导出包含aten::elu
操作符的模型了!
请参阅torch/onnx/symbolic_opset*.py
文件以获取更多示例。
torch.autograd.Functions
如果操作符是 torch.autograd.Function
的子类,则有三种导出方法。
静态符号方法
你可以在函数类中添加一个名为symbolic
的静态方法,该方法应返回代表该函数行为的ONNX操作符。例如:
class MyRelu(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.save_for_backward(input) return input.clamp(min=0) @staticmethod def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value: return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
内联自动求导函数
当没有为后续的torch.autograd.Function
提供静态符号方法,或者没有提供注册prim::PythonOp
作为自定义符号函数的方法时,torch.onnx.export()
将尝试内联与该torch.autograd.Function
对应的图,并将其分解为函数内部使用的单个操作符。只要这些单独的操作符受到支持,导出过程就应该成功。例如:
class MyLogExp(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.save_for_backward(input) h = input.exp() return h.log().log()
尽管此模型没有静态符号方法,但它仍然以如下方式被导出:
graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)): %1 : float = onnx::Exp[](%input) %2 : float = onnx::Log[](%1) %3 : float = onnx::Log[](%2) return (%3)
如果你需要避免内联torch.autograd.Function
,你应该将operator_export_type
设置为 ONNX_FALLTHROUGH
或 ONNX_ATEN_FALLBACK
来导出模型。
自定义操作符
你可以导出自带多种标准ONNX操作组合或由自定义C++后端驱动的自定义操作符的模型。
ONNX-script函数
如果一个操作符不是标准的ONNX操作符,但可以由多个现有的ONNX操作符组合而成,则可以利用ONNX-script 创建外部ONNX函数来支持该操作符。你可以参考以下示例进行导出:
import onnxscript # There are three opset version needed to be aligned # This is (1) the opset version in ONNX function from onnxscript.onnx_opset import opset15 as op opset_version = 15 x = torch.randn(1, 2, 3, 4, requires_grad=True) model = torch.nn.SELU() custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) @onnxscript.script(custom_opset) def Selu(X): alpha = 1.67326 # auto wrapped as Constants gamma = 1.0507 alphaX = op.CastLike(alpha, X) gammaX = op.CastLike(gamma, X) neg = gammaX * (alphaX * op.Exp(X) - alphaX) pos = gammaX * X zero = op.CastLike(0, X) return op.Where(X <= zero, neg, pos) # setType API provides shape/type to ONNX shape/type inference def custom_selu(g: jit_utils.GraphContext, X): return g.onnxscript_op(Selu, X).setType(X.type()) # Register custom symbolic function # There are three opset version needed to be aligned # This is (2) the opset version in registry torch.onnx.register_custom_op_symbolic( symbolic_name="aten::selu", symbolic_fn=custom_selu, opset_version=opset_version, ) # There are three opset version needed to be aligned # This is (2) the opset version in exporter torch.onnx.export( model, x, "model.onnx", opset_version=opset_version, # only needed if you want to specify an opset version > 1. custom_opsets={"onnx-script": 2} )
上述示例展示了如何将自定义操作符导出到“onnx-script”算子集中。在导出自定义操作符时,可以使用custom_opsets
字典来指定自定义域版本。如果没有指定,默认的自定义算子集版本为1。
注意:请确保将上述示例中的opset版本对齐,并在导出步骤中使用它们。关于如何编写onnx-script函数的示例用法目前处于onnx-script积极开发的测试版阶段,请参考最新的ONNX-script。
C++运算符
如果一个模型使用了如使用自定义C++操作符扩展TorchScript中描述的自定义C++操作符实现,你可以参考以下示例进行导出:
from torch.onnx import symbolic_helper # Define custom symbolic function @symbolic_helper.parse_args("v", "v", "f", "i") def symbolic_foo_forward(g, input1, input2, attr1, attr2): return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2) # Register custom symbolic function torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9) class FooModel(torch.nn.Module): def __init__(self, attr1, attr2): super().__init__() self.attr1 = attr1 self.attr2 = attr2 def forward(self, input1, input2): # Calling custom op return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2) model = FooModel(attr1, attr2) torch.onnx.export( model, (example_input1, example_input1), "model.onnx", # only needed if you want to specify an opset version > 1. custom_opsets={"custom_domain": 2} )
上述示例将该操作作为“custom_domain”操作集中的自定义操作进行导出。在导出自定义操作时,可以使用custom_opsets
字典来指定自定义域的版本。如果没有指定,默认的自定义操作集版本为1。
使用模型的运行时需要支持自定义操作符。参见Caffe2 自定义操作符和ONNX Runtime 自定义操作符,或者查阅你选择的运行时文档。
一次发现所有无法转换的ATen操作
当导出失败是由于存在无法转换的ATen操作时,实际上可能存在多个此类操作,但错误消息仅提到第一个。为了一次性找出所有无法转换的操作,你可以:
# prepare model, args, opset_version ... torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops( model, args, opset_version=opset_version ) print(set(unconvertible_ops))
该集合是近似的,因为在转换过程中可能会移除一些操作,无需进行转换。其他一些操作可能只有一部分支持,在特定输入下会失败,但这应该能给你一个大致的概念,哪些操作不受支持。请随时为操作支持请求打开GitHub问题。
常见问题
问题:我已经导出了我的LSTM模型,但是发现它的输入大小是固定的?
追踪器记录示例输入的形状。如果模型需要接受动态形状的输入,请在调用
torch.onnx.export()
时设置dynamic_axes
。
如何导出包含循环的模型?
参见追踪 vs 脚本化。
如何导出具有原生类型输入(如 int、float)的模型?
PyTorch 1.9 添加了对原始数值类型输入的支持。然而,导出器不支持包含字符串输入的模型。
问题:ONNX 是否支持隐式的标量数据类型转换?
虽然 ONNX 标准没有规定,但导出器会尝试处理这部分内容。标量将被导出为常量张量,并且导出器会自动确定其正确数据类型。但在极少数情况下无法做到时,你需要手动指定数据类型,例如 dtype=torch.float32。如果遇到任何错误,请[创建一个 GitHub 问题](https://github.com/pytorch/pytorch/issues)。
问题:可以将Tensor列表导出到ONNX吗?
是的,当
opset_version
大于或等于 11 时,因为 ONNX 在操作集 11 中引入了 Sequence 类型。
Python API
函数
- torch.onnx.export(model, args=(), f=None, *, kwargs=None, export_params=True, verbose=None, input_names=None, output_names=None, opset_version=None, dynamic_axes=None, keep_initializers_as_inputs=False, dynamo=False, external_data=True, dynamic_shapes=None, report=False, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True, **_)[源代码]
-
将模型导出为ONNX格式。
- 参数
-
-
model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – 需要导出的模型。
-
args (tuple[Any, ...]) – 示例位置输入。任何非张量参数将被硬编码到导出的模型中;任何张量参数将成为导出模型的输入,并按其在元组中的顺序排列。
解释更清晰版本:
args (tuple[Any, ...]) – 示例位置输入。任何非张量参数将被固定在导出的模型中;任何张量参数将成为导出模型的输入,并按其在元组中的顺序排列。
-
f (str | os.PathLike | None) – 输出 ONNX 模型文件的路径,例如 “model.onnx”。
-
export_params (bool) – 如果设置为 false,则不导出参数(权重)。
-
verbose (bool|None) – 是否开启详细日志记录。
-
input_names (Sequence[str] | None) – 按顺序为图的输入节点指定的名称。
-
output_names (Sequence[str] | None) – 按顺序为图的输出节点指定名称。
-
opset_version (int | None) – 默认操作集(ai.onnx)的目标版本。必须大于等于 7。
-
dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –
默认情况下,导出的模型会将所有输入和输出张量的形状与
args
中提供的完全匹配。 若要指定在运行时才确定的张量轴(即动态轴),可以将dynamic_axes
设置为具有以下模式的字典:-
-
KEY (str): 输入或输出的名称。每个名称都必须包含在
input_names
中。 -
output_names
-
KEY (str): 输入或输出的名称。每个名称都必须包含在
-
- VALUE (字典或列表): 如果是字典,键为轴索引,值为轴名称。如果是列表,
-
这是一个列表,其中每个元素都是一个轴索引。
例如:
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], )
输出:
input { name: "x" ... shape { dim { dim_value: 2 # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_value: 2 # axis 0 ...
当:
torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], }, )
输出:
input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0 ...
-
-
keep_initializers_as_inputs (bool) - 是否将初始化器作为输入保留
如果为 True,导出图中的所有初始化器(通常对应于模型权重)将被添加为图的输入。如果为 False,则不将初始化器添加为图的输入,只添加用户指定的输入。
如果你打算在运行时提供模型权重,请将此值设为True。如果权重是固定的,则将其设为False,以便后端或运行时能够进行更高效的优化(如常量折叠)。
-
dynamo (bool) – 是否使用
torch.export
导出的程序来替代TorchScript进行模型导出。 -
external_data (bool) – 是否将模型的权重保存为外部数据文件。对于权重过大,超过ONNX文件大小限制(2GB)的模型,这是必需的。如果设置为False,则权重会与模型架构一起保存在同一个ONNX文件中。
-
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – 模型输入的动态形状字典。更多详情请参阅
torch.export.export()
。此参数仅在 dynamo 为 True 时使用,并且优先推荐使用。在同一时间只应设置一个参数 dynamic_axes 或 dynamic_shapes。 -
report (bool) – 是否为导出过程生成Markdown报告。
-
verify (bool) – 是否通过 ONNX Runtime 来验证导出的模型。
-
profile (bool) – 是否启用导出过程的性能分析。
-
dump_exported_program (bool) – 是否将
torch.export.ExportedProgram
导出到文件。这对于调试导出器非常有用。 -
artifacts_dir (str|os.PathLike) – 用于保存调试工件(例如报告和序列化导出程序)的目录。
-
fallback (bool) — 是否在 Dynamo 出器失败时回退到 TorchScript 导出器。
-
training (_C_onnx.TrainingMode) – 已废弃的选项,请在导出模型前设置模型的训练模式。
-
operator_export_type (_C_onnx.OperatorExportTypes) – 已废弃的选项,当前仅支持 ONNX。
-
do_constant_folding (bool) – 已废弃的选项。导出的图形将始终被优化。
-
custom_opsets (Mapping[str, int] | None) –
已弃用。这是一个字典:
-
KEY (str): opset域的名称
-
VALUE (int): 操作集版本号
如果自定义运算符集合被
模型
引用但未在此字典中提及,则其版本将默认设为1。仅通过此参数指定自定义运算符集合的域名和版本。 -
-
export_modules_as_functions (bool 或 Collection[type[torch.nn.Module]]) –
已废弃的选项。
启用将所有
nn.Module
的前向调用导出为ONNX中的本地函数的功能标志。或者,使用集合来指定要作为ONNX中本地函数导出的具体模块类型。此功能需要opset_version
>= 15,否则导出会失败。这是因为当opset_version
< 15时,IR版本会小于8,这意味着不支持本地函数。模块变量将作为函数属性进行导出。函数属性分为两类。1. 注解属性:通过PEP 526风格进行类型注解的类变量将被导出为属性。虽然这些注解属性不在ONNX局部函数的子图中使用(因为它们不是由PyTorch JIT跟踪创建的),但消费者可能会利用这些属性来决定是否用特定的融合内核替换该函数。
2. 推断属性:模块内操作符使用的变量,其名称带有前缀“inferred::”。这有助于将其与从Python模块注释中获取的预定义属性区分开来。推断属性在ONNX局部函数的子图内部使用。
-
False
(默认情况下):将nn.Module
的前向呼叫导出为细粒度节点。 -
True
: 将所有nn.Module
的前向呼叫导出为局部函数节点。 -
-
类型为 nn.Module 的集合:将
nn.Module
的 forward 方法调用导出为局部函数节点。 -
仅当 nn.Module 的类型存在于集合中时。
-
类型为 nn.Module 的集合:将
-
-
autograd_inlining (bool) – 已废弃。用于控制是否内联自动微分函数的标志。更多详情,请参阅 https://github.com/pytorch/pytorch/pull/74765。
-
- 返回类型
-
任意类型 | None
- torch.onnx.export_to_pretty_string(model, args, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, export_type=None, google_printer=False, opset_version=None, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True, dynamic_axes=None)[源代码]
-
自版本 2.5 起弃用: 将在未来版本中移除。建议使用
onnx.printer.to_text()
替代。类似于
export()
,但返回 ONNX 模型的文本表示形式。仅列出参数的不同之处,其他所有参数与export()
相同。
- torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[源代码]
-
为自定义运算符注册一个符号函数。
当用户为自定义或贡献的操作符注册符号时,强烈建议通过 setType API 为此操作符添加形状推断。否则,在某些极端情况下,导出的图可能会出现不正确的形状推断。例如,在 test_operators.py 中有一个名为test_aten_embedding_2 的示例。
参见模块文档中的“自定义操作符”,了解示例用法。
- torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[源代码]
-
取消注册
符号名
。参见模块文档中的“自定义操作符”,了解示例用法。
- torch.onnx.select_model_mode_for_export(model, mode)[源代码]
-
一个上下文管理器,用于临时将
model
的训练模式设置为mode
,并在退出with
代码块时进行重置。
- torch.onnx.enable_log()[源代码]
-
开启ONNX日志记录。
- torch.onnx.disable_log()[源代码]
-
关闭ONNX的日志记录。
- torch.onnx.verification.find_mismatch(model, input_args, do_constant_folding=True, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, options=None)[源代码]
-
找出原始模型和导出模型之间所有的不一致之处。
试验性功能。API可能随时变更。
此工具用于帮助调试原始PyTorch模型与导出的ONNX模型之间的差异。它通过二分搜索模型图,找到并展示导致差异的最小部分。
- 参数
-
-
model (torch.nn.Module | torch.jit.ScriptModule) – 需要导出的模型。
-
input_args (tuple[Any, ...]) – 模型的输入参数列表。
-
do_constant_folding (bool) – 与
torch.onnx.export()
中的 do_constant_folding 参数相同。 -
training (_C_onnx.TrainingMode) – 与
torch.onnx.export()
中的 training 参数相同。 -
opset_version (int|None) – 与
torch.onnx.export()
中的opset_version
参数相同。 -
keep_initializers_as_inputs (bool) – 与
torch.onnx.export()
中的 keep_initializers_as_inputs 参数相同。 -
verbose (bool) – 与
torch.onnx.export()
中的 verbose 参数相同。 -
options (VerificationOptions | None) – 用于差异验证的选项。
-
- 返回值
-
一个包含不匹配信息的 GraphInfo 对象。
- 返回类型
示例:
>>> import torch >>> import torch.onnx.verification >>> torch.manual_seed(0) >>> opset_version = 15 >>> # Define a custom symbolic function for aten::relu. >>> # The custom symbolic function is incorrect, which will result in mismatches. >>> def incorrect_relu_symbolic_function(g, self): ... return self >>> torch.onnx.register_custom_op_symbolic( ... "aten::relu", ... incorrect_relu_symbolic_function, ... opset_version=opset_version, ... ) >>> class Model(torch.nn.Module): ... def __init__(self) -> None: ... super().__init__() ... self.layers = torch.nn.Sequential( ... torch.nn.Linear(3, 4), ... torch.nn.ReLU(), ... torch.nn.Linear(4, 5), ... torch.nn.ReLU(), ... torch.nn.Linear(5, 6), ... ) ... def forward(self, x): ... return self.layers(x) >>> graph_info = torch.onnx.verification.find_mismatch( ... Model(), ... (torch.randn(2, 3),), ... opset_version=opset_version, ... ) ===================== Mismatch info for graph partition : ====================== ================================ Mismatch error ================================ Tensor-likes are not close! Mismatched elements: 12 / 12 (100.0%) Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) ==================================== Tree: ===================================== 5 X __2 X __1 \u2713 id: | id: 0 | id: 00 | | | |__1 X (aten::relu) | id: 01 | |__3 X __1 \u2713 id: 1 | id: 10 | |__2 X __1 X (aten::relu) id: 11 | id: 110 | |__1 \u2713 id: 111 =========================== Mismatch leaf subgraphs: =========================== ['01', '110'] ============================= Mismatch node kinds: ============================= {'aten::relu': 2}
类
Torch中定义的标量类型。 |
|
GraphInfo 包含 TorchScript 图及其转换为 ONNX 图的验证信息。 |
|
用于 ONNX 导出验证的选项。 |