TorchScript

TorchScript 是一种从 PyTorch 代码生成可序列化和可优化模型的方法。任何 TorchScript 程序都可以从一个 Python 进程中保存,并在一个无需 Python 的环境中加载。

我们提供了工具,可以逐步将模型从纯Python程序转换为独立于Python运行的TorchScript程序(例如在独立的C++程序中)。这使得可以在Python中使用熟悉的工具训练PyTorch模型,并通过TorchScript导出该模型到生产环境,在这种环境下,由于性能和多线程的原因,Python程序可能处于劣势。

了解 TorchScript 的入门指南,请参阅TorchScript 入门教程

要查看将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的端到端示例,请参阅在 C++ 中加载 PyTorch 模型教程。

创建TorchScript代码

script

编写函数脚本。

trace

跟踪一个函数,并返回可执行文件或ScriptFunction,后者将通过即时编译进行优化。

script_if_tracing

在跟踪过程中首次调用时编译fn

trace_module

跟踪一个模块,并返回一个可执行的ScriptModule,该模块会通过即时编译来进行优化。

fork

创建一个异步任务来执行 func,并引用该任务执行结果的值。

wait

强制完成一个torch.jit.Future[T]异步任务,并返回该任务的结果。

ScriptModule

C++ torch::jit::Module 的一个包装器,包含了方法、属性和参数。

ScriptFunction

功能上与ScriptModule相同,但代表一个单独的函数,并且没有属性或参数。

freeze

将 ScriptModule 冻结,并将其子模块和属性作为常量内联。

optimize_for_inference

进行一系列优化操作,以便将模型优化用于推理目的。

enable_onednn_fusion

基于参数enabled来启用或禁用onednn JIT融合。

onednn_fusion_enabled

返回onednn JIT融合是否已启用。

set_fusion_strategy

设定在融合过程中可以出现的特化类型的数量。

strict_fusion

如果没有在推理过程中融合所有节点,或者在训练过程中没有进行符号微分,则报告错误。

save

保存此模块的离线版本,以便在单独的进程中使用。

load

使用 torch.jit.save 之前保存的 ScriptModuleScriptFunction 进行加载。

ignore

这个装饰器告诉编译器忽略某个函数或方法,并将其保留在 Python 函数形式。

unused

这个装饰器告诉编译器,某个函数或方法应被忽略,并替换为抛出异常。

interface

使用装饰器来标记不同类型的类或模块。

isinstance

在TorchScript中提供容器类型的具体化。

Attribute

此方法是一个传递函数,返回value,主要用于告知TorchScript编译器,左侧表达式是一个具有类型type的类实例属性。

annotate

用于在TorchScript编译器中为the_value指定类型。

结合追踪和 scripting 对于 "scripting",如果需要中文表达可以改为“脚本”,但根据上下文,保留英文原词可能更合适。所以这里保持不变。

在许多情况下,将模型转换为TorchScript时,使用追踪或脚本化方法会更容易。追踪和脚本化可以结合使用,以满足模型某个部分的特定需求。

脚本化函数可以调用被追踪的函数。这在你需要在一个简单的前馈模型中使用控制流时特别有用。例如,序列到序列模型的束搜索通常会以脚本形式编写,但可以调用通过追踪生成的编码器模块。

示例(在脚本中调用被跟踪的函数):

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

可追溯函数可以调用脚本函数。当模型中有一小部分需要控制流逻辑,而大部分模型只是简单的前馈网络时,这一点特别有用。在这种情况下,脚本函数内的控制流会得到正确的保留。

示例(在跟踪的函数中调用脚本函数):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

这种组合同样适用于nn.Module,可以在其中使用追踪来生成一个子模块,并且这个子模块可以从脚本模块的方法中进行调用。

示例(使用带有追踪的模块):

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript语言

TorchScript 是一个静态类型的 Python 子集,因此许多 Python 特性可以直接用于 TorchScript。更多细节请参阅 TorchScript 语言参考

内置函数和模块

TorchScript 支持大多数 PyTorch 函数和许多 Python 内置函数。详情请参阅 TorchScript 内置函数 获取支持函数的完整列表。

PyTorch 函数与模块

TorchScript 支持 PyTorch 提供的部分张量和神经网络功能。这包括 Tensor 对象上的大多数方法,torch 命名空间中的所有函数,torch.nn.functional 中的所有函数以及来自 torch.nn 的大部分模块。

参见TorchScript 不支持的 PyTorch 构造,了解不支持的 PyTorch 函数和模块列表。

Python 函数与模块

Python 的许多内置函数在TorchScript中得到了支持。此外,math模块也得到了支持(详情请参阅math 模块),但不支持其他Python模块,无论是内置的还是第三方的。

Python 语言参考手册对比

有关支持的 Python 特性完整列表,请参见 Python 语言参考

调试

禁用JIT调试

PyTorch JIT

设置环境变量 PYTORCH_JIT=0 可以禁用所有 TorchScript 的脚本和追踪功能。如果你的 TorchScript 模型中存在难以调试的问题,可以使用此标志强制所有内容使用原生 Python 运行。由于该标志禁用了 TorchScript(包括脚本化和跟踪),你可以利用像 pdb 这样的工具来调试模型代码。例如:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

使用pdb调试此脚本可以正常工作,除非我们调用了@torch.jit.script函数。我们可以全局禁用JIT,这样就可以将@torch.jit.script作为普通的Python函数调用而不进行编译。如果上述脚本名为disable_jit_example.py,我们可以像这样调用它:

$ PYTORCH_JIT=0 python disable_jit_example.py

我们就可以像调用普通 Python 函数一样进入 @torch.jit.script 函数。要为特定函数禁用 TorchScript 编译器,请参见 @torch.jit.ignore

检查代码

TorchScript 为所有 ScriptModule 实例提供了一个代码美化打印功能,该功能将脚本方法的代码解释为有效的 Python 语法。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

一个具有单个 forward 方法的 ScriptModule 将包含一个属性 code,你可以使用它来查看该模块的代码。如果 ScriptModule 包含多个方法,则需要在具体的方法上访问其 .code 属性,而不是整个模块。例如,可以通过访问 .foo.code 来查看名为 foo 的方法的代码。上面的例子会产生以下输出:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

这是TorchScript编译的forward方法的代码。你可以用它来确保TorchScript(无论是通过追踪还是脚本方式)正确地捕获了你的模型代码。

解释图表

TorchScript 还有一个比代码美化打印器更低层次的表示形式,即 IR 图。

TorchScript 使用静态单赋值(SSA)中间表示(IR)来描述计算过程。该格式中的指令由 ATen(PyTorch 的 C++ 后端)操作符和其他基本操作符组成,包括循环和条件判断的操作符。例如:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph检查代码 部分中提到的 forward 方法查找规则相同。

上述示例脚本生成了如下图形:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

以以下指令为例:%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10

  • %rv.1: Tensor 表示我们将输出分配给一个名为 rv.1 的唯一值,该值是 Tensor 类型,并且我们不知道其具体形状。

  • aten::zeros 是一个操作符(等同于 torch.zeros),输入列表 (%4, %6, %6, %10, %12) 指定了哪些值应作为输入传递。内置函数如 aten::zeros 的详细信息可以在内置函数部分找到。

  • # test.py:9:10 表示生成该指令的原始源文件的位置。具体来说,这是名为test.py的文件中的第9行,第10个字符位置。

请注意,操作符还可以包含相关的blocks,例如prim::Loopprim::If。在图形输出中,这些操作符的格式会与源代码形式一致,便于调试。

可以通过以下方式检查图,确认ScriptModule描述的计算是否正确,无论是自动化的还是手动的方式。

追踪器

追踪边缘案例

有些情况下,给定的Python函数或模块的跟踪信息可能无法准确反映其底层代码的情况。这些情况可能包括:

  • 基于输入(如张量形状)的控制流跟踪

  • 张量视图的原地操作跟踪(例如,赋值语句左侧的索引操作)

请注意,未来可能能够追踪到这些情况。

自动追踪检查

自动捕获跟踪中许多错误的一种方法是使用 torch.jit.trace() API 上的 check_inputs 参数。check_inputs 接受一个输入元组列表,用于重新执行计算并验证结果。例如:

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

提供以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

这条消息表明,在我们第一次追踪和使用check_inputs 进行第二次追踪时,计算结果出现了差异。确实,loop_in_traced_fn 体内的循环依赖于输入 x 的形状,因此当尝试具有不同形状的另一个 x 时,追踪的结果会有所不同。

在这种情况下,可以使用torch.jit.script()来捕获数据依赖的控制流:

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

产生的结果是:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

跟踪器警告

追踪器会为追溯计算中的某些问题模式生成警告。例如,考虑一个函数的跟踪记录,该函数包含对张量切片(视图为)的就地赋值操作:

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

生成几个警告以及一个仅返回输入的简单图表。

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我们可以通过修改代码,避免使用就地更新,改为使用torch.cat方法在外置位置构建结果张量。

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常见问题

问题:我在GPU上训练模型,在CPU上进行推理,有什么最佳实践吗?

首先将模型从 GPU 转换到 CPU,然后保存它,如下所示:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

这样推荐是因为追踪器可能观察到在特定设备上创建张量的过程,因此对已加载模型进行类型转换可能会导致意想不到的结果。在保存模型前进行类型转换可以确保追踪器获取正确的设备信息。

如何在ScriptModule上存储属性?

假设我们有这样的一个模型:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

如果实例化Model,由于编译器不知道x,将会导致编译错误。有4种方法可以告知编译器ScriptModule的属性:

1. nn.Parameter - 在 nn.Parameter 中包裹的值将在 nn.Module 中像平常那样工作。

2. register_buffer - 在 register_buffer 中包装的值将像在 nn.Module 上一样工作。这相当于一个类型为 Tensor 的属性(参见 4)。

3. 常量 - 将类成员标注为 Final,或将其添加到类定义级别的名为 __constants__ 的列表中,这样可以将这些名称标记为常量。常量会直接保存在模型的代码中。详情请参见内置常量

4. 属性 - 值为支持的类型可以作为可变属性添加。大多数类型的值可以自动推断,但也有一些类型需要手动指定,请参阅模块属性以获取详细信息。

问题:我想跟踪模块的方法,但总是遇到这个错误:

RuntimeError: 无法将需要梯度的 Tensor 作为常量插入。 考虑将其设为参数或输入,或者断开梯度。

这个错误通常表示你在跟踪一个方法时使用了模块的参数,但传递的是模块的方法而不是模块实例(例如 my_module_instance.forwardmy_module_instance)。

  • 调用带有模块方法的 trace 会将模块参数(这些参数可能需要梯度)捕获为 常量

  • 另一方面,使用模块实例(例如my_module)调用trace会创建一个新模块,并正确复制参数到新模块中,这样在需要时就可以累积梯度。

要追踪模块中的特定方法,请参见 torch.jit.trace_module

已知问题

如果你在使用Sequential与TorchScript,即使已经进行了其他注解,Sequential子模块的一些输入也可能被错误地推断为Tensor。标准的解决方案是子类化nn.Sequential并重新声明带有正确类型的forward方法。

附录

迁移至PyTorch 1.2递归脚本API

本节详细介绍了 PyTorch 1.2 中 TorchScript 的更新内容。如果你是首次使用 TorchScript,可以跳过此节。PyTorch 1.2 主要对 TorchScript API 进行了两项更改。

1. torch.jit.script 现在将尝试递归编译遇到的函数、方法和类。一旦你调用 torch.jit.script,编译就是默认开启的,而不是默认关闭的。

2. torch.jit.script(nn_module_instance) 现已成为创建ScriptModule 的首选方法,替代了继承自torch.jit.ScriptModule的方式。这些变化共同提供了一个更简单、更易于使用的 API,用于将你的 nn.Modules 转换为ScriptModules,并准备好在非 Python 环境中进行优化和执行。

新的使用方法如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • 模块的forward默认会被编译。从forward中调用的方法会按照在forward中的使用顺序进行延迟编译。

  • 要编译在 forward 之外且不通过 forward 调用的方法,请添加 @torch.jit.export 注解。

  • 要阻止编译器编译某个方法,可以添加@torch.jit.ignore 或者 @torch.jit.unused。使用@ignore会保留该方法的定义。

  • 将方法作为对Python的调用,并使用@unused 替换它来抛出异常。 @ignored 不能被导出;而@unused 可以。

  • 大多数属性类型可以被自动推断出来,因此不需要使用torch.jit.Attribute。对于空容器类型,请采用PEP 526风格的类注释来标注它们的类型。

  • 可以使用Final类注解来标记常量,而无需将成员名称添加到__constants__中。

  • 在 Python 3 中,类型提示可以代替 torch.jit.annotate 的使用。

由于这些更改,以下项目已被视为弃用,在新代码中不应再出现:
  • 装饰器 @torch.jit.script_method

  • 继承自torch.jit.ScriptModule的类

  • torch.jit.Attribute 包装类

  • __constants__ 数组

  • torch.jit.annotate 函数

模块

警告

在 PyTorch 1.2 中,@torch.jit.ignore 注解的行为发生了变化。在 PyTorch 1.2 之前,@ignore 装饰器用于使函数或方法可以从导出的代码中调用。要恢复此功能,请使用 @torch.jit.unused()。现在 @torch.jit.ignore 等同于 @torch.jit.ignore(drop=False)。详情请参阅@torch.jit.ignore@torch.jit.unused

当传递给torch.jit.script 函数时,torch.nn.Module 的数据会被复制到一个 ScriptModule 中,并且 TorchScript 编译器会编译该模块。默认情况下,模块的 forward 方法会被编译。从 forward 调用的方法会在它们被使用时按顺序延迟编译,以及任何带有 @torch.jit.export 注解的方法。

torch.jit.export(fn)[源代码]

此装饰器表示 nn.Module 中的方法用作 ScriptModule 的入口点,并应进行编译。

forward 默认被视为入口点,因此无需使用该装饰器。从 forward 调用的函数和方法会在编译器遇到时被编译,所以这些函数和方法同样不需要该装饰器。

示例(在方法上使用 @torch.jit.export):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

函数

函数基本不变,如有需要,可以使用@torch.jit.ignoretorch.jit.unused进行装饰。

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 类

警告

TorchScript 类支持处于试验阶段,目前最适用于简单的记录类型(类似带有方法的 NamedTuple)。

用户定义的TorchScript 类中的所有内容默认都会被导出。如果需要,可以使用@torch.jit.ignore装饰器来忽略某些函数。

属性

TorchScript 编译器需要知道模块属性的类型。大多数类型的值可以从成员变量中推断出来,但空列表和字典无法通过其值来推断类型,必须使用 PEP 526 风格 的类注释进行标注。如果未明确标注且无法推断出类型,则该属性不会被添加到结果 ScriptModule 中。

旧API:

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

新API:

from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

常量

Final 类型构造函数可以用于将成员标记为常量。如果没有被标记为常量,这些成员将会被复制到生成的ScriptModule 中作为属性。使用 Final 可以在值已知且固定的情况下提供优化机会,并增加类型安全性。

旧API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

新API:

from typing import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

变量

容器假设具有 Tensor 类型,并且是非可选的(有关更多信息,请参见默认类型)。此前,使用 torch.jit.annotate 来告知 TorchScript 编译器应使用的类型。现在支持 Python 3 风格的类型提示。

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

融合后端

有几个融合后端可以用来优化 TorchScript 的执行。CPU 上的默认融合器是 NNC,它可以为 CPU 和 GPU 执行融合操作。GPU 上的默认融合器是 NVFuser,它支持更广泛的算子,并且已经展示了具有改进吞吐量的生成内核。有关使用和调试的更多详细信息,请参阅NVFuser 文档

本页目录