TorchScript 语言参考

TorchScript 是一个静态类型的 Python 子集,可以直接编写(使用 @torch.jit.script 装饰器)或通过追踪从 Python 代码自动生成。当使用追踪时,代码会自动转换为这个子集,只记录张量上的实际操作,并执行和丢弃其他周围的 Python 代码。

当直接使用@torch.jit.script装饰器编写TorchScript时,程序员只能使用TorchScript支持的一小部分Python功能。本节将记录这些被支持的功能,就像它们是独立语言的参考文档一样。此参考中未提及的任何Python特性都不属于TorchScript。有关PyTorch张量方法、模块和函数的完整列表,请参见内置函数

作为Python的一个子集,任何有效的TorchScript函数也是有效的Python函数。这使得可以禁用TorchScript并使用标准的Python调试工具如pdb来调试该函数。反之则不然:有许多有效的Python程序在TorchScript中是无效的。相反,TorchScript专注于表示PyTorch中的神经网络模型所需的特定Python特性。

类型

TorchScript与完整的Python语言之间的最大区别在于,它只支持一组有限的类型,这些类型用于表达神经网络模型所必需。具体来说,TorchScript支持以下内容:

类型

描述

张量

任何数据类型、维度或后端的 PyTorch 张量

Tuple[T0, T1, ..., TN]

包含子类型 T0T1 等的元组(例如 Tuple[Tensor, Tensor]

bool

一个布尔类型的值

int

一个整数标量

float

一个标量浮点数

str

字符串

List[T]

一个所有成员都是类型 T 的列表

Optional[T]

这个值要么是 None,要么是类型 T

Dict[K, V]

一个键类型为K(可以是strintfloat)和值类型为V的字典。

T

TorchScript 类

E

一个TorchScript 枚举类型

NamedTuple[T0, T1, ...]

一种 collections.namedtuple 元组类型

Union[T0, T1, ...]

其中一个子类型是 T0T1 等等。

与 Python 不同,TorchScript 函数中的每个变量都必须有一个固定的静态类型。这样可以更容易地优化 TorchScript 函数。

示例(类型错误)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...

不支持的类型构造

TorchScript 不支持 typing 模块中的所有特性和类型。其中一些基础特性未来不太可能被添加,而其他一些则可能会根据用户需求在未来被考虑加入。

来自 typing 模块的这些类型和特性在 TorchScript 中不可用。

项目

描述

typing.Any

typing.Any 当前还在开发中,尚未发布

typing.NoReturn

未实现

typing.Sequence

未实现

typing.Callable

未实现

typing.Literal

未实现

typing.ClassVar

未实现

typing.Final

这支持模块属性的类属性注解,但不支持函数

typing.AnyStr

TorchScript 不支持 bytes,所以不会使用这种类型。

typing.overload

typing.overload 当前还在开发中,尚未发布

类型别名

未实现

名义子类型和结构子类型

名义类型化正在开发中,但结构类型化还未实现。

新类型

不太可能被实施

泛型

不太可能被实施

除了本文档中明确列出的功能之外,typing 模块中的其他功能均不受支持。

默认类型

默认情况下,TorchScript 函数的所有参数都假定为张量类型。若要指定某个参数为其他类型,则可以使用上述列出的 MyPy 风格的类型注解。

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用typing模块中的 Python 3 类型提示来标注类型。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

空列表被视为List[Tensor],空字典为Dict[str, Tensor]。若需创建其他类型的空列表或字典,请使用Python 3 类型提示

示例(Python 3 类型注解):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

可选类型的细化

TorchScript 在 if 语句的条件中或在 assert 中对类型为 Optional[T] 的变量与 None 进行比较时,会细化该变量的类型。编译器可以推断多个通过 andornot 结合的 None 检查的结果。对于未显式书写的 if 语句的 else 块,也会进行类型细化。

None 的检查必须直接放在 if 语句的条件中;将 None 检查的结果赋值给一个变量,然后在 if 语句中使用该变量不会细化这些变量的类型。只有局部变量会被细化,例如 self.x 这样的属性不会被细化,并且需要将其赋值给一个局部变量才能进行类型细化。

示例(通过参数和局部变量来细化类型):

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 类

警告

TorchScript 类支持处于试验阶段,目前最适用于简单的记录类型(类似带有方法的 NamedTuple)。

如果你用 @torch.jit.script 注解 Python 类,它们就可以在 TorchScript 中使用,就像声明一个 TorchScript 函数一样:

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

这个子集受到限制:

  • 所有的函数都必须是有效的TorchScript函数,包括 __init__()

  • 类必须是新式类,因为我们会使用 __new__() 方法并通过 pybind11 来创建它们。

  • TorchScript 类是静态类型的。成员变量只能在 __init__() 方法中通过将值赋给 self 来声明。

    例如,在__init__()方法之外将self赋值给其他变量:

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    将导致:

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
  • 类的主体中只能包含方法定义,不能有其他表达式。

  • 不支持继承及其他任何形式的多态策略,但可以通过从object继承来指定新式类。

定义类之后,它可以在 TorchScript 和 Python 中像其他任何 TorchScript 类型一样互相交换使用。

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 枚举

Python 枚举可以直接在 TorchScript 中使用,无需添加额外的注解或代码。

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定义枚举后,可以在 TorchScript 和 Python 中像使用其他任何 TorchScript 类型一样互换使用该枚举。枚举值的类型必须是 intfloatstr,并且所有值的类型必须一致;不支持不同类型混用的枚举值。

命名元组

collections.namedtuple生成的类型可以用于TorchScript。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

可迭代对象

某些函数(例如,zipenumerate)只能用于可迭代类型。在 TorchScript 中,可迭代类型包括 Tensor、列表、元组、字典、字符串、torch.nn.ModuleListtorch.nn.ModuleDict

表达式

以下 Python 表达式被支持。

字面量

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表构建

一个空列表被假设为类型 List[Tensor]。其他列表字面量的类型根据其成员的类型来确定。详情请参阅默认类型

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元组构造

(3, 4)
(3,)

字典构建

一个空字典被假设为类型 Dict[str, Tensor]。其他字典的类型根据其成员的类型推导得出。更多详情请参见默认类型

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

变量

关于变量的解析方法,请参见变量解析

my_variable_name

算术运算符

a + b
a - b
a * b
a / b
a ^ b
a @ b

比较运算符

a == b
a != b
a < b
a > b
a <= b
a >= b

逻辑运算符

a and b
a or b
not b

下标与切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函数调用

调用内置函数

torch.rand(3, dtype=torch.int)

调用其他脚本中的函数:

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法调用

调用像张量这样的内置类型的函数方法:x.mm(y)

在模块中,方法必须先编译才能被调用。TorchScript 编译器会递归地编译它遇到的方法,在编译其他方法时进行处理。默认情况下,从 forward 方法开始编译。任何由 forward 调用的方法都会被编译,并且这些方法调用的其他方法也会被递归地编译。若要从不同于 forward 的方法开始编译,请使用 @torch.jit.export 装饰器(forward 方法默认隐式标记为 @torch.jit.export)。

直接调用子模块(如 self.resnet(input))与调用其 forward 方法(如 self.resnet.forward(input))是等价的。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表达式

x if x > y else y

类型转换

float(ten)
int(3.5)
bool(ten)
str(2)``

访问模块参数

self.my_parameter
self.my_submodule.my_parameter

语句

TorchScript 支持以下类型的语句:

简单任务

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配赋值

a, b = tuple_or_list
a, b, *c = a_tuple

多重赋值

a = b, c = tup

条件语句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除布尔值外,浮点数、整数和张量也可用于条件语句,并会被隐式转换为布尔值。

while 循环

a = 0
while a < 4:
    print(a)
    a += 1

使用range的for循环

x = 0
for i in range(10):
    x *= i

遍历元组的for循环

这些操作展开循环,为元组中的每个成员生成相应的代码块,并且每个代码块都必须通过类型检查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

遍历常量nn.ModuleList的for循环

要在编译方法内使用nn.ModuleList,必须将其标记为常量,具体操作是将属性名称添加到类型的__constants__列表中。对于遍历nn.ModuleList的循环,在编译时会展开循环体,每个模块列表中的成员都会被单独处理。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

断点和继续

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

返回

return a, b

变量分辨率

TorchScript 支持 Python 的变量解析(即作用域)规则的一个子集。局部变量的行为与 Python 中相同,但有一个限制:变量在函数的所有路径上必须保持相同的类型。如果一个变量在一个 if 语句的不同分支中具有不同的类型,则在 if 语句之后使用该变量会导致错误。

类似地,如果一个变量仅在函数中的某些路径上被定义,则不允许使用该变量。

示例:

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

在函数定义时,非局部变量会在编译期间被解析为Python值。然后,这些Python值会根据使用Python值中的规则转换为TorchScript值。

Python值的使用

为了使编写TorchScript更加方便,我们允许脚本代码引用周围作用域中的Python值。例如,每当引用 torch 时,实际上当函数声明时,TorchScript编译器将其解析为 torch Python模块。这些Python值不是TorchScript的一部分,而是会在编译时被转换成TorchScript支持的基本类型。这种转换取决于在编译发生时引用的Python值的动态类型。本节描述了访问TorchScript中Python值所遵循的规则。

函数

TorchScript 可以调用 Python 函数,这在逐步将模型转换为 TorchScript 时非常有用。你可以逐函数地将模型移至 TorchScript,并保留对 Python 函数的调用。这样可以在每一步中增量检查模型的正确性。

torch.jit.is_scripting()[源代码]

一个在编译时返回True、其他情况下返回False的函数。特别是在使用@unused装饰器时非常有用,可以保留模型中尚未与TorchScript兼容的代码。.. testcode:

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
    if torch.jit.is_scripting():
        return torch.linear(x)
    else:
        return unsupported_linear_op(x)
返回类型

bool

torch.jit.is_tracing()[源代码]

返回布尔值。

在使用 torch.jit.trace 追踪代码期间,如果调用了某个函数,则返回 True,否则返回 False

查询Python模块中的属性

TorchScript 可以在模块上查找属性,例如通过这种方式访问像 torch.add 这样的内置函数。这样,TorchScript 就能调用其他模块中定义的函数了。

Python中的常量

TorchScript 还支持使用在 Python 中定义的常量。这些常量可以用于将超参数硬编码到函数中,或定义通用常量。有兩種方式可以指定一個 Python 值應該被視為常量。

  1. 作为模块属性查找到的值被视作常量:

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. 可以使用Final[T]注解将ScriptModule的属性标记为常量

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支持的Python常量类型包括

  • int

  • float

  • bool

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支持类型的元组

  • torch.nn.ModuleList 可以用于 TorchScript 中的循环

模块属性

torch.nn.Parameter 包装器和 register_buffer 可用于将张量分配给模块。其他可以被编译的模块赋值的内容,如果其类型可推断,则会被添加到编译后的模块中。所有在 TorchScript 中可用的类型都可以作为模块属性使用。张量属性与缓冲区在语义上相同。空列表、字典和 None 值的类型无法推断,必须通过PEP 526 风格类注释来指定。如果一个类型的值不能被推断且没有显式标注,则它不会作为属性添加到生成的 ScriptModule 中。

示例:

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))
本页目录