torch.jit.trace_module
- torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[源代码]
-
跟踪一个模块,并返回一个可执行的
ScriptModule
,该模块会通过即时编译来进行优化。当一个模块被传递给
torch.jit.trace
时,只有forward
方法会被运行并进行追踪。使用trace_module
,你可以指定一个包含方法名和示例输入的字典来进行追踪(参见下面的inputs
参数)。关于追踪的更多内容,请参见
torch.jit.trace
。- 参数
-
-
mod (torch.nn.Module) – 一个包含在
inputs
中指定的方法的torch.nn.Module
。给定的方法将作为单个ScriptModule的一部分进行编译。 -
inputs (dict) – 一个包含样本输入的字典,这些输入由
mod
中的方法名索引。在跟踪过程中,输入将被传递给与键对应的方法。{'forward': example_forward_input, 'method2': example_method2_input}
-
- 关键字参数
-
-
check_trace (
bool
, optional) – 检查相同的输入通过跟踪代码是否产生相同的输出。默认值为True
。如果你的网络包含非确定性操作,或者你确信即使检查器失败网络也是正确的,你可以禁用此选项。 -
check_inputs (list of dicts, optional) – 一个包含输入参数字典的列表,用于检查跟踪是否符合预期。每个字典相当于在
inputs
中指定的一组输入参数。为了获得最佳结果,请传入一组代表网络可能遇到的不同形状和类型的输入进行检查。如果没有指定,则使用原始的inputs
进行检查。 -
check_tolerance (float, 可选) – 在检查过程中使用的浮点数比较容差。如果由于某些已知原因(如操作符融合)导致结果出现数值差异,可以使用此参数来放松检查的严格性。
-
example_inputs_is_kwarg (
bool
, 可选) – 表示示例输入是否为关键字参数的包。默认值:False
。
-
- 返回值
-
一个包含单个
forward
方法(内含追踪代码)的ScriptModule
对象。如果func
是一个torch.nn.Module
,则返回的ScriptModule
将具有与func
相同的一组子模块和参数。
示例(具有多个方法的模块的追踪):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} module = torch.jit.trace_module(n, inputs)