torch.jit.trace

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[源代码]

跟踪一个函数,并返回一个可执行文件或 ScriptFunction,该文件会通过即时编译进行优化。

追踪适用于仅操作Tensor及其列表、字典和元组的代码。

使用torch.jit.tracetorch.jit.trace_module,你可以将现有的模块或Python函数转换为TorchScript ScriptFunctionScriptModule。你需要提供示例输入,系统会运行该函数并记录所有张量上的操作。

  • 独立函数的记录结果会产生 ScriptFunction

  • 记录 nn.Module.forwardnn.Module 后,会生成 ScriptModule

此模块还包括原始模块的所有参数。

警告

追踪仅记录那些不依赖于数据(例如,不在张量中进行条件判断)且没有任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)的函数和模块。它只记录在给定张量上运行特定函数时的操作。因此,返回的ScriptModule将在任何输入上始终运行相同的追踪图。当你的模块预期根据输入和/或模块状态的不同而执行不同的操作集时,这会产生一些重要的影响。例如,

  • 追踪不会记录任何控制流程,如 if 语句或循环。当这种控制流程在整个模块中是恒定的时,这是可以接受的,并且通常会内联这些决策。但有时控制流实际上是模型的一部分。例如,递归网络会对输入序列(可能是动态长度)进行循环处理。

  • 在返回的ScriptModule中,那些在trainingeval模式下行为不同的操作将始终按照记录时的模式运行,无论当前ScriptModule处于哪种模式。

在这种情况下,使用追踪是不合适的,而脚本化是一个更好的选择。如果你对这样的模型进行追踪,可能会在后续调用该模型时默默地得到错误的结果。当执行可能导致生成错误跟踪的操作时,追踪器会尝试发出警告。

参数

func (callabletorch.nn.Module) – 一个 Python 函数或 torch.nn.Module,将使用 example_inputs 运行。函数的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。当传递模块时,只有 forward 方法会被运行并进行追踪(详情请参阅 torch.jit.trace)。

关键字参数
  • example_inputs (元组torch.TensorNone, 可选) – 用于在跟踪过程中传递给函数的示例输入元组。默认值为 None。此参数或 example_kwarg_inputs 中的一个必须指定。生成的跟踪可以使用不同类型的输入和形状运行,前提是所跟踪的操作支持这些类型和形状。example_inputs 也可以是一个单独的 Tensor,在这种情况下它会被自动包装成一个元组。当值为 None 时,example_kwarg_inputs 应该被指定。

  • check_trace (bool, optional) – 检查相同的输入通过跟踪代码是否产生相同的输出。默认值为True。如果你的网络包含非确定性操作,或者你确信即使检查器失败网络也是正确的,你可以禁用此选项。

  • check_inputs (list of tuples, optional) – 一个包含输入参数元组的列表,用于检查跟踪是否符合预期。每个元组等同于在 example_inputs 中指定的一组输入参数。为了获得最佳效果,请传入一组具有代表性的检查输入,这些输入涵盖了你期望网络看到的各种形状和类型。如果未指定,则使用原始的 example_inputs 进行检查。

  • check_tolerance (float, 可选) – 在检查过程中使用的浮点数比较容差。如果由于某些已知原因(如操作符融合)导致结果出现数值差异,可以使用此参数来放松检查的严格性。

  • strict (bool, 可选) – 是否以严格模式运行追踪器(默认值: True)。仅在你希望追踪器记录可变容器类型(当前为list/dict),并且确信你在问题中使用的容器是一个constant结构且不会用作控制流(如 if、for 语句)条件时,才关闭此模式。

  • example_kwarg_inputs (dict, 可选) – 此参数包含一组关键字参数,这些参数是示例输入,在跟踪函数时将传递给该函数。默认值: None。此参数或example_inputs必须指定一个。字典中的键会根据被跟踪函数的参数名称进行解包。如果字典中的键与被跟踪函数的参数名不匹配,则会引发运行时异常。

返回值

如果 funcnn.Modulenn.Moduleforward 方法,trace 将返回一个包含单个 forward 方法的 ScriptModule 对象,该方法包含了追踪到的代码。返回的 ScriptModule 具有与原始 nn.Module 相同的一组子模块和参数。如果 func 是一个独立函数,则 trace 将返回 ScriptFunction

示例(跟踪一个函数):

import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

示例(跟踪现有模块):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)
本页目录