TorchScript 语言参考
TorchScript 是一个静态类型的 Python 子集,可以直接编写(使用 @torch.jit.script
装饰器)或通过追踪从 Python 代码自动生成。当使用追踪时,代码会自动转换为这个子集,只记录张量上的实际操作,并执行和丢弃其他周围的 Python 代码。
当直接使用@torch.jit.script
装饰器编写TorchScript时,程序员只能使用TorchScript支持的一小部分Python功能。本节将记录这些被支持的功能,就像它们是独立语言的参考文档一样。此参考中未提及的任何Python特性都不属于TorchScript。有关PyTorch张量方法、模块和函数的完整列表,请参见内置函数。
作为Python的一个子集,任何有效的TorchScript函数也是有效的Python函数。这使得可以禁用TorchScript并使用标准的Python调试工具如pdb
来调试该函数。反之则不然:有许多有效的Python程序在TorchScript中是无效的。相反,TorchScript专注于表示PyTorch中的神经网络模型所需的特定Python特性。
类型
TorchScript与完整的Python语言之间的最大区别在于,它只支持一组有限的类型,这些类型用于表达神经网络模型所必需。具体来说,TorchScript支持以下内容:
类型 |
描述 |
---|---|
|
任何数据类型、维度或后端的 PyTorch 张量 |
|
包含子类型 |
|
一个布尔类型的值 |
|
一个整数标量 |
|
一个标量浮点数 |
|
字符串 |
|
一个所有成员都是类型 |
|
这个值要么是 None,要么是类型 |
|
一个键类型为 |
|
TorchScript 类 |
|
|
|
一种 |
|
其中一个子类型是 |
与 Python 不同,TorchScript 函数中的每个变量都必须有一个固定的静态类型。这样可以更容易地优化 TorchScript 函数。
示例(类型错误)
import torch @torch.jit.script def an_error(x): if x: r = torch.rand(1) else: r = 4 return r
Traceback (most recent call last): ... RuntimeError: ... Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: @torch.jit.script def an_error(x): if x: ~~~~~ r = torch.rand(1) ~~~~~~~~~~~~~~~~~ else: ~~~~~ r = 4 ~~~~~ <--- HERE return r and was used here: else: r = 4 return r ~ <--- HERE...
不支持的类型构造
TorchScript 不支持 typing
模块中的所有特性和类型。其中一些基础特性未来不太可能被添加,而其他一些则可能会根据用户需求在未来被考虑加入。
来自 typing
模块的这些类型和特性在 TorchScript 中不可用。
项目 |
描述 |
---|---|
|
|
未实现 |
|
未实现 |
|
未实现 |
|
未实现 |
|
未实现 |
|
这支持模块属性的类属性注解,但不支持函数 |
|
TorchScript 不支持 |
|
|
|
类型别名 |
未实现 |
名义子类型和结构子类型 |
名义类型化正在开发中,但结构类型化还未实现。
|
新类型 |
不太可能被实施 |
泛型 |
不太可能被实施 |
除了本文档中明确列出的功能之外,typing
模块中的其他功能均不受支持。
默认类型
默认情况下,TorchScript 函数的所有参数都假定为张量类型。若要指定某个参数为其他类型,则可以使用上述列出的 MyPy 风格的类型注解。
import torch @torch.jit.script def foo(x, tup): # type: (int, Tuple[Tensor, Tensor]) -> Tensor t0, t1 = tup return t0 + t1 + x print(foo(3, (torch.rand(3), torch.rand(3))))
注意
也可以使用typing
模块中的 Python 3 类型提示来标注类型。
import torch from typing import Tuple @torch.jit.script def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: t0, t1 = tup return t0 + t1 + x print(foo(3, (torch.rand(3), torch.rand(3))))
空列表被视为List[Tensor]
,空字典为Dict[str, Tensor]
。若需创建其他类型的空列表或字典,请使用Python 3 类型提示。
示例(Python 3 类型注解):
import torch import torch.nn as nn from typing import Dict, List, Tuple class EmptyDataStructures(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: # This annotates the list to be a `List[Tuple[int, float]]` my_list: List[Tuple[int, float]] = [] for i in range(10): my_list.append((i, x.item())) my_dict: Dict[str, int] = {} return my_list, my_dict x = torch.jit.script(EmptyDataStructures())
可选类型的细化
TorchScript 在 if 语句的条件中或在 assert
中对类型为 Optional[T]
的变量与 None
进行比较时,会细化该变量的类型。编译器可以推断多个通过 and
、or
和 not
结合的 None
检查的结果。对于未显式书写的 if 语句的 else 块,也会进行类型细化。
对 None
的检查必须直接放在 if 语句的条件中;将 None
检查的结果赋值给一个变量,然后在 if 语句中使用该变量不会细化这些变量的类型。只有局部变量会被细化,例如 self.x
这样的属性不会被细化,并且需要将其赋值给一个局部变量才能进行类型细化。
示例(通过参数和局部变量来细化类型):
import torch import torch.nn as nn from typing import Optional class M(nn.Module): z: Optional[int] def __init__(self, z): super().__init__() # If `z` is None, its type cannot be inferred, so it must # be specified (above) self.z = z def forward(self, x, y, z): # type: (Optional[int], Optional[int], Optional[int]) -> int if x is None: x = 1 x = x + 1 # Refinement for an attribute by assigning it to a local z = self.z if y is not None and z is not None: x = y + z # Refinement via an `assert` assert z is not None x += z return x module = torch.jit.script(M(2)) module = torch.jit.script(M(None))
TorchScript 类
警告
TorchScript 类支持处于试验阶段,目前最适用于简单的记录类型(类似带有方法的 NamedTuple
)。
如果你用 @torch.jit.script
注解 Python 类,它们就可以在 TorchScript 中使用,就像声明一个 TorchScript 函数一样:
@torch.jit.script class Foo: def __init__(self, x, y): self.x = x def aug_add_x(self, inc): self.x += inc
这个子集受到限制:
-
所有的函数都必须是有效的TorchScript函数,包括
__init__()
。 -
类必须是新式类,因为我们会使用
__new__()
方法并通过 pybind11 来创建它们。 -
TorchScript 类是静态类型的。成员变量只能在
__init__()
方法中通过将值赋给 self 来声明。例如,在
__init__()
方法之外将self
赋值给其他变量:@torch.jit.script class Foo: def assign_x(self): self.x = torch.rand(2, 3)
将导致:
RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
-
类的主体中只能包含方法定义,不能有其他表达式。
-
不支持继承及其他任何形式的多态策略,但可以通过从
object
继承来指定新式类。
定义类之后,它可以在 TorchScript 和 Python 中像其他任何 TorchScript 类型一样互相交换使用。
# Declare a TorchScript class @torch.jit.script class Pair: def __init__(self, first, second): self.first = first self.second = second @torch.jit.script def sum_pair(p): # type: (Pair) -> Tensor return p.first + p.second p = Pair(torch.rand(2, 3), torch.rand(2, 3)) print(sum_pair(p))
TorchScript 枚举
Python 枚举可以直接在 TorchScript 中使用,无需添加额外的注解或代码。
from enum import Enum class Color(Enum): RED = 1 GREEN = 2 @torch.jit.script def enum_fn(x: Color, y: Color) -> bool: if x == Color.RED: return True return x == y
定义枚举后,可以在 TorchScript 和 Python 中像使用其他任何 TorchScript 类型一样互换使用该枚举。枚举值的类型必须是 int
、float
或 str
,并且所有值的类型必须一致;不支持不同类型混用的枚举值。
命名元组
由collections.namedtuple
生成的类型可以用于TorchScript。
import torch import collections Point = collections.namedtuple('Point', ['x', 'y']) @torch.jit.script def total(point): # type: (Point) -> Tensor return point.x + point.y p = Point(x=torch.rand(3), y=torch.rand(3)) print(total(p))
可迭代对象
某些函数(例如,zip
和 enumerate
)只能用于可迭代类型。在 TorchScript 中,可迭代类型包括 Tensor
、列表、元组、字典、字符串、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
表达式
以下 Python 表达式被支持。
字面量
True False None 'string literals' "string literals" 3 # interpreted as int 3.4 # interpreted as a float
列表构建
一个空列表被假设为类型 List[Tensor]
。其他列表字面量的类型根据其成员的类型来确定。详情请参阅默认类型。
[3, 4] [] [torch.rand(3), torch.rand(4)]
元组构造
(3, 4) (3,)
字典构建
一个空字典被假设为类型 Dict[str, Tensor]
。其他字典的类型根据其成员的类型推导得出。更多详情请参见默认类型。
{'hello': 3} {} {'a': torch.rand(3), 'b': torch.rand(4)}
算术运算符
a + b a - b a * b a / b a ^ b a @ b
比较运算符
a == b a != b a < b a > b a <= b a >= b
逻辑运算符
a and b a or b not b
下标与切片
t[0] t[-1] t[0:2] t[1:] t[:1] t[:] t[0, 1] t[0, 1:2] t[0, :1] t[-1, 1:, 0] t[1:, -1, 0] t[i:j, i]
函数调用
调用内置函数
torch.rand(3, dtype=torch.int)
调用其他脚本中的函数:
import torch @torch.jit.script def foo(x): return x + 1 @torch.jit.script def bar(x): return foo(x)
方法调用
调用像张量这样的内置类型的函数方法:x.mm(y)
在模块中,方法必须先编译才能被调用。TorchScript 编译器会递归地编译它遇到的方法,在编译其他方法时进行处理。默认情况下,从 forward
方法开始编译。任何由 forward
调用的方法都会被编译,并且这些方法调用的其他方法也会被递归地编译。若要从不同于 forward
的方法开始编译,请使用 @torch.jit.export
装饰器(forward
方法默认隐式标记为 @torch.jit.export
)。
直接调用子模块(如 self.resnet(input)
)与调用其 forward
方法(如 self.resnet.forward(input)
)是等价的。
import torch import torch.nn as nn import torchvision class MyModule(nn.Module): def __init__(self): super().__init__() means = torch.tensor([103.939, 116.779, 123.68]) self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) resnet = torchvision.models.resnet18() self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) def helper(self, input): return self.resnet(input - self.means) def forward(self, input): return self.helper(input) # Since nothing in the model calls `top_level_method`, the compiler # must be explicitly told to compile this method @torch.jit.export def top_level_method(self, input): return self.other_helper(input) def other_helper(self, input): return input + 10 # `my_script_module` will have the compiled methods `forward`, `helper`, # `top_level_method`, and `other_helper` my_script_module = torch.jit.script(MyModule())
三元表达式
x if x > y else y
类型转换
float(ten) int(3.5) bool(ten) str(2)``
访问模块参数
self.my_parameter self.my_submodule.my_parameter
语句
简单任务
a = b a += b # short-hand for a = a + b, does not operate in-place on a a -= b
模式匹配赋值
a, b = tuple_or_list a, b, *c = a_tuple
多重赋值
a = b, c = tup
打印语句
print("the result of an add:", a + b)
条件语句
if a < 4: r = -a elif a < 3: r = a + a else: r = 3 * a
除布尔值外,浮点数、整数和张量也可用于条件语句,并会被隐式转换为布尔值。
while 循环
a = 0 while a < 4: print(a) a += 1
使用range的for循环
x = 0 for i in range(10): x *= i
遍历元组的for循环
这些操作展开循环,为元组中的每个成员生成相应的代码块,并且每个代码块都必须通过类型检查。
tup = (3, torch.rand(4)) for x in tup: print(x)
遍历常量nn.ModuleList的for循环
要在编译方法内使用nn.ModuleList
,必须将其标记为常量,具体操作是将属性名称添加到类型的__constants__
列表中。对于遍历nn.ModuleList
的循环,在编译时会展开循环体,每个模块列表中的成员都会被单独处理。
class SubModule(torch.nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.randn(2)) def forward(self, input): return self.weight + input class MyModule(torch.nn.Module): __constants__ = ['mods'] def __init__(self): super().__init__() self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) def forward(self, v): for module in self.mods: v = module(v) return v m = torch.jit.script(MyModule())
断点和继续
for i in range(5): if i == 1: continue if i == 3: break print(i)
返回
return a, b
变量分辨率
TorchScript 支持 Python 的变量解析(即作用域)规则的一个子集。局部变量的行为与 Python 中相同,但有一个限制:变量在函数的所有路径上必须保持相同的类型。如果一个变量在一个 if 语句的不同分支中具有不同的类型,则在 if 语句之后使用该变量会导致错误。
类似地,如果一个变量仅在函数中的某些路径上被定义,则不允许使用该变量。
示例:
@torch.jit.script def foo(x): if x < 0: y = 4 print(y)
Traceback (most recent call last): ... RuntimeError: ... y is not defined in the false branch... @torch.jit.script... def foo(x): if x < 0: ~~~~~~~~~ y = 4 ~~~~~ <--- HERE print(y) and was used here: if x < 0: y = 4 print(y) ~ <--- HERE...
在函数定义时,非局部变量会在编译期间被解析为Python值。然后,这些Python值会根据使用Python值中的规则转换为TorchScript值。
Python值的使用
为了使编写TorchScript更加方便,我们允许脚本代码引用周围作用域中的Python值。例如,每当引用 torch
时,实际上当函数声明时,TorchScript编译器将其解析为 torch
Python模块。这些Python值不是TorchScript的一部分,而是会在编译时被转换成TorchScript支持的基本类型。这种转换取决于在编译发生时引用的Python值的动态类型。本节描述了访问TorchScript中Python值所遵循的规则。
函数
TorchScript 可以调用 Python 函数,这在逐步将模型转换为 TorchScript 时非常有用。你可以逐函数地将模型移至 TorchScript,并保留对 Python 函数的调用。这样可以在每一步中增量检查模型的正确性。
- torch.jit.is_scripting()[源代码]
-
一个在编译时返回True、其他情况下返回False的函数。特别是在使用@unused装饰器时非常有用,可以保留模型中尚未与TorchScript兼容的代码。.. testcode:
import torch @torch.jit.unused def unsupported_linear_op(x): return x def linear(x): if torch.jit.is_scripting(): return torch.linear(x) else: return unsupported_linear_op(x)
- 返回类型
- torch.jit.is_tracing()[源代码]
-
返回布尔值。
在使用
torch.jit.trace
追踪代码期间,如果调用了某个函数,则返回True
,否则返回False
。
查询Python模块中的属性
TorchScript 可以在模块上查找属性,例如通过这种方式访问像 torch.add
这样的内置函数。这样,TorchScript 就能调用其他模块中定义的函数了。
Python中的常量
TorchScript 还支持使用在 Python 中定义的常量。这些常量可以用于将超参数硬编码到函数中,或定义通用常量。有兩種方式可以指定一個 Python 值應該被視為常量。
-
作为模块属性查找到的值被视作常量:
import math import torch @torch.jit.script def fn(): return math.pi
-
可以使用
Final[T]
注解将ScriptModule的属性标记为常量
import torch import torch.nn as nn class Foo(nn.Module): # `Final` from the `typing_extensions` module can also be used a : torch.jit.Final[int] def __init__(self): super().__init__() self.a = 1 + 4 def forward(self, input): return self.a + input f = torch.jit.script(Foo())
支持的Python常量类型包括
-
int
-
float
-
bool
-
torch.device
-
torch.layout
-
torch.dtype
-
包含支持类型的元组
-
torch.nn.ModuleList
可以用于 TorchScript 中的循环
模块属性
torch.nn.Parameter
包装器和 register_buffer
可用于将张量分配给模块。其他可以被编译的模块赋值的内容,如果其类型可推断,则会被添加到编译后的模块中。所有在 TorchScript 中可用的类型都可以作为模块属性使用。张量属性与缓冲区在语义上相同。空列表、字典和 None
值的类型无法推断,必须通过PEP 526 风格类注释来指定。如果一个类型的值不能被推断且没有显式标注,则它不会作为属性添加到生成的 ScriptModule
中。
示例:
from typing import List, Dict class Foo(nn.Module): # `words` is initialized as an empty list, so its type must be specified words: List[str] # The type could potentially be inferred if `a_dict` (below) was not # empty, but this annotation ensures `some_dict` will be made into the # proper type some_dict: Dict[str, int] def __init__(self, a_dict): super().__init__() self.words = [] self.some_dict = a_dict # `int`s can be inferred self.my_int = 10 def forward(self, input): # type: (str) -> int self.words.append(input) return self.some_dict[input] + self.my_int f = torch.jit.script(Foo({'hi': 2}))