torch.jit.trace
- torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[源代码]
-
跟踪一个函数,并返回一个可执行文件或
ScriptFunction
,该文件会通过即时编译进行优化。追踪适用于仅操作
Tensor
及其列表、字典和元组的代码。使用torch.jit.trace和torch.jit.trace_module,你可以将现有的模块或Python函数转换为TorchScript
ScriptFunction
或ScriptModule
。你需要提供示例输入,系统会运行该函数并记录所有张量上的操作。-
独立函数的记录结果会产生 ScriptFunction。
-
记录 nn.Module.forward 或 nn.Module 后,会生成 ScriptModule。
此模块还包括原始模块的所有参数。
警告
追踪仅记录那些不依赖于数据(例如,不在张量中进行条件判断)且没有任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)的函数和模块。它只记录在给定张量上运行特定函数时的操作。因此,返回的ScriptModule将在任何输入上始终运行相同的追踪图。当你的模块预期根据输入和/或模块状态的不同而执行不同的操作集时,这会产生一些重要的影响。例如,
-
追踪不会记录任何控制流程,如 if 语句或循环。当这种控制流程在整个模块中是恒定的时,这是可以接受的,并且通常会内联这些决策。但有时控制流实际上是模型的一部分。例如,递归网络会对输入序列(可能是动态长度)进行循环处理。
-
在返回的
ScriptModule
中,那些在training
和eval
模式下行为不同的操作将始终按照记录时的模式运行,无论当前ScriptModule处于哪种模式。
在这种情况下,使用追踪是不合适的,而
脚本化
是一个更好的选择。如果你对这样的模型进行追踪,可能会在后续调用该模型时默默地得到错误的结果。当执行可能导致生成错误跟踪的操作时,追踪器会尝试发出警告。- 参数
-
func (callable 或 torch.nn.Module) – 一个 Python 函数或 torch.nn.Module,将使用 example_inputs 运行。函数的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。当传递模块时,只有
forward
方法会被运行并进行追踪(详情请参阅torch.jit.trace
)。 - 关键字参数
-
-
example_inputs (元组 或 torch.Tensor 或 None, 可选) – 用于在跟踪过程中传递给函数的示例输入元组。默认值为
None
。此参数或example_kwarg_inputs
中的一个必须指定。生成的跟踪可以使用不同类型的输入和形状运行,前提是所跟踪的操作支持这些类型和形状。example_inputs 也可以是一个单独的 Tensor,在这种情况下它会被自动包装成一个元组。当值为 None 时,example_kwarg_inputs
应该被指定。 -
check_trace (
bool
, optional) – 检查相同的输入通过跟踪代码是否产生相同的输出。默认值为True
。如果你的网络包含非确定性操作,或者你确信即使检查器失败网络也是正确的,你可以禁用此选项。 -
check_inputs (list of tuples, optional) – 一个包含输入参数元组的列表,用于检查跟踪是否符合预期。每个元组等同于在
example_inputs
中指定的一组输入参数。为了获得最佳效果,请传入一组具有代表性的检查输入,这些输入涵盖了你期望网络看到的各种形状和类型。如果未指定,则使用原始的example_inputs
进行检查。 -
check_tolerance (float, 可选) – 在检查过程中使用的浮点数比较容差。如果由于某些已知原因(如操作符融合)导致结果出现数值差异,可以使用此参数来放松检查的严格性。
-
strict (
bool
, 可选) – 是否以严格模式运行追踪器(默认值:True
)。仅在你希望追踪器记录可变容器类型(当前为list
/dict
),并且确信你在问题中使用的容器是一个constant
结构且不会用作控制流(如 if、for 语句)条件时,才关闭此模式。 -
example_kwarg_inputs (dict, 可选) – 此参数包含一组关键字参数,这些参数是示例输入,在跟踪函数时将传递给该函数。默认值:
None
。此参数或example_inputs
必须指定一个。字典中的键会根据被跟踪函数的参数名称进行解包。如果字典中的键与被跟踪函数的参数名不匹配,则会引发运行时异常。
-
- 返回值
-
如果 func 是 nn.Module 或 nn.Module 的
forward
方法,trace 将返回一个包含单个forward
方法的ScriptModule
对象,该方法包含了追踪到的代码。返回的 ScriptModule 具有与原始nn.Module
相同的一组子模块和参数。如果func
是一个独立函数,则trace
将返回 ScriptFunction。
示例(跟踪一个函数):
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
示例(跟踪现有模块):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)
-