TorchScript 语言参考
此参考手册描述了TorchScript语言的语法和核心语义。TorchScript是Python的一个静态类型子集。本文件解释了TorchScript支持的Python特性,以及它与常规Python的不同之处。任何未在此手册中提及的Python功能都不属于TorchScript。TorchScript专注于表示PyTorch神经网络模型所需的Python特性。
术语
本文档采用了以下术语:
模式 |
注释 |
---|---|
|
表示给定的符号是如何被定义的。 |
|
代表实际的关键词和分隔符,这些是语法的一部分。 |
|
表示要么是A,要么是B。 |
|
表示分类分组。 |
|
表示选项。 |
|
表示一个正则表达式,其中术语A至少出现一次。 |
|
表示在正则表达式中,项A可以重复零次或多次。 |
类型系统
TorchScript 是一个静态类型的 Python 子集。它与完整 Python 语言之间的最大区别在于,TorchScript 只支持一组有限的类型,这些类型足以表示神经网络模型。
TorchScript 类型
TorchScript 类型系统包含 TSType
和 TSModuleType
,如下定义。
TSAllType ::= TSType | TSModuleType TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSType
代表大多数可以组合的 TorchScript 类型,并且可以在 TorchScript 类型注释中使用。具体来说,TSType
包括以下几种类型中的任意一种:
-
元类型,例如
Any
-
基本类型,例如:
int
,float
,和str
-
结构类型,例如
Optional[int]
或List[MyClass]
-
名义类型(如 Python 类),例如:
MyClass
(用户定义),torch.tensor
(内置)
TSModuleType
表示 torch.nn.Module
及其子类。它与 TSType
不同,因为它的类型模式部分从对象实例推断而来,部分从类定义中得出。因此,TSModuleType
的实例可能不会遵循相同的静态类型模式。TSModuleType
不能用作 TorchScript 类型注解,并且出于类型安全的考虑,也不能与 TSType
组合。
元类型
元类型非常抽象,更像是一种类型约束而非具体类型。当前,TorchScript 定义了一种名为 Any
的元类型,用于表示任意的 TorchScript 类型。
Any
类型
Any
类型表示任何 TorchScript 类型。由于 Any
没有类型约束,因此不会对其进行类型检查。它可以绑定到任何 Python 或 TorchScript 数据类型(例如,int
、TorchScript tuple
或未编写的任意 Python 类)。
TSMetaType ::= "Any"
具体来说:
-
Any
是 Python 中 typing 模块的一个类型。因此,要使用Any
类型,需要从typing
导入它(例如,from typing import Any
)。 -
因为
Any
可以代表任何 TorchScript 类型,所以可以在Any
上使用的操作符集合是有限的。
Any
类型支持的运算符
-
将数据赋值给类型为
Any
的变量。 -
将绑定到类型为
Any
的参数或返回值。 -
x 是
,x 不是
,其中x
是Any
类型。 -
isinstance(x, Type)
,其中x
是Any
类型。 -
任何类型的数据都是可以打印的。
-
类型为
List[Any]
的数据可能具有可排序性,如果该数据是一个由相同类型T
的值组成的列表,并且类型T
支持比较运算符。
与Python相比
Any
是 TorchScript 类型系统中约束最少的类型。从这个意义上说,它与 Python 中的 Object
类非常相似。然而,Any
只支持一小部分由 Object
支持的运算符和方法。
设计笔记
当我们为 PyTorch 模块编写脚本时,可能会遇到一些在执行过程中不涉及的数据。然而,这些数据仍然需要通过类型模式进行描述。不仅描述未使用数据的静态类型(在脚本上下文中)很繁琐,还可能导致不必要的编译错误。Any
被引入来描述那些在编译时不需要精确静态类型的數據的类型。
示例 1
此示例展示了如何使用Any
让元组参数中的第二个元素可以是任意类型。这是因为在涉及x[1]
的任何计算中,并不需要知道其确切类型。
import torch from typing import Tuple from typing import Any @torch.jit.export def inc_first_element(x: Tuple[int, Any]): return (x[0]+1, x[1]) m = torch.jit.script(inc_first_element) print(m((1,2.0))) print(m((1,(100,200))))
上述示例的输出如下:
(2, 2.0) (2, (100, 200))
元组的第二个元素是Any
类型,因此可以绑定到多种类型的值。例如,(1, 2.0)
将浮点数2.0
绑定为Any
类型,即Tuple[int, Any]
;而(1, (100, 200))
在第二次调用时将元组(100, 200)
绑定到Any
。
示例 2
此示例展示了如何使用isinstance
来动态检查标记为Any
类型的数据的类型。
import torch from typing import Any def f(a:Any): print(a) return (isinstance(a, torch.Tensor)) ones = torch.ones([2]) m = torch.jit.script(f) print(m(ones))
上述示例的输出如下:
1 1 [ CPUFloatType{2} ] True
基本类型
TorchScript的基本类型代表单一值的类型,并且使用一个预先定义的类型名称。
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
结构类型
结构类型是指没有用户自定义名称(不同于命名类型)的结构性定义的类型,例如 Future[int]
。结构类型可以与任何其他 TSType
进行组合。
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | TSUnion | TSFuture | TSRRef | TSAwait TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" TSList ::= "List" "[" TSType "]" TSOptional ::= "Optional" "[" TSType "]" TSUnion ::= "Union" "[" (TSType ",")* TSType "]" TSFuture ::= "Future" "[" TSType "]" TSRRef ::= "RRef" "[" TSType "]" TSAwait ::= "Await" "[" TSType "]" TSDict ::= "Dict" "[" KeyType "," TSType "]" KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
具体来说:
-
Tuple
,List
,Optional
,Union
,Future
, 和Dict
是在typing
模块中定义的 Python 类型类名称。要使用这些类型名称,需要从typing
导入它们(例如,from typing import Tuple
)。 -
namedtuple
代表 Python 中的collections.namedtuple
或typing.NamedTuple
类。 -
Future
和RRef
分别对应 Python 中的torch.futures
和torch.distributed.rpc
类。 -
Await
是 Python 中torch._awaits._Await
类的实例。
与Python相比
除了可以与TorchScript类型结合使用之外,这些TorchScript结构化类型通常还支持与其对应的Python类型相同的运算符和方法的公共子集。
示例 1
此示例使用 typing.NamedTuple
语法来定义一个元组:
import torch from typing import NamedTuple from typing import Tuple class MyTuple(NamedTuple): first: int second: int def inc(x: MyTuple) -> Tuple[int, int]: return (x.first+1, x.second+1) t = MyTuple(first=1, second=2) scripted_inc = torch.jit.script(inc) print("TorchScript:", scripted_inc(t))
上述示例的输出如下:
TorchScript: (2, 3)
示例 2
此示例使用collections.namedtuple
语法定义一个元组:
import torch from typing import NamedTuple from typing import Tuple from collections import namedtuple _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) _UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second']) def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: return (x.first+1, x.second+1) m = torch.jit.script(inc) print(inc(_UnannotatedNamedTuple(1,2)))
上述示例的输出如下:
(2, 3)
示例 3
此示例展示了在使用结构类型时常见的错误:未从typing
模块导入复合类型类。
import torch # ERROR: Tuple not recognized because not imported from typing @torch.jit.export def inc(x: Tuple[int, int]): return (x[0]+1, x[1]+1) m = torch.jit.script(inc) print(m((1,2)))
运行上述代码会生成以下脚本错误:
File "test-tuple.py", line 5, in <module> def inc(x: Tuple[int, int]): NameError: name 'Tuple' is not defined
解决方案是在代码的开头添加以下行:from typing import Tuple
。
名义类型
TorchScript 的名义类型是 Python 类。这些类型之所以称为名义类型,是因为它们用自定义名称声明,并且通过类名来进行比较。名义类可以进一步分类为以下几种:
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
其中,TSCustomClass
和 TSEnum
必须能被编译成 TorchScript 中间表示(IR)。这是由类型检查器强制要求的。
内置类
内置的名义类型是指那些其语义被直接集成到TorchScript系统中的Python类,比如张量类型。TorchScript为这些内置的名义类型定义了具体的语义,并且通常只实现该类在Python中定义的方法或属性的一个子集。
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor
关于 torch.nn.ModuleList 和 torch.nn.ModuleDict 的特别说明
虽然 torch.nn.ModuleList
和 torch.nn.ModuleDict
在 Python 中被定义为列表和字典,但它们在 TorchScript 中的表现更像是元组。
-
在 TorchScript 中,
torch.nn.ModuleList
和torch.nn.ModuleDict
的实例是不可变的。 -
迭代
torch.nn.ModuleList
或torch.nn.ModuleDict
的代码会被完全展开,使得torch.nn.ModuleList
中的元素或torch.nn.ModuleDict
中的键可以是torch.nn.Module
的不同子类。
示例
以下示例展示了如何使用一些内置的Torchscript类(torch.*
):
import torch @torch.jit.script class A: def __init__(self): self.x = torch.rand(3) def f(self, y: torch.device): return self.x.to(device=y) def g(): a = A() return a.f(torch.device("cpu")) script_g = torch.jit.script(g) print(script_g.graph)
自定义类
与内置类不同,自定义类的语义由用户定义,并且整个类定义必须能编译成TorchScript IR,并遵循TorchScript的类型检查规则。
TSClassDef ::= [ "@torch.jit.script" ] "class" ClassName [ "(object)" ] ":" MethodDefinition | [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] MethodDefinition
具体来说:
-
类必须是新式类。Python 3 只支持新式类。在 Python 2.x 中,通过从 object 继承来定义一个新式类。
-
实例数据属性是静态类型的,实例属性必须在
__init__()
方法中的赋值语句里进行声明。 -
不支持方法重载(即,你不可以定义多个同名的方法)。
-
MethodDefinition
必须能编译成 TorchScript IR,并遵循 TorchScript 的类型检查规则:所有方法都必须是有效的 TorchScript 函数,而类属性定义则必须是有效的 TorchScript 语句。 -
torch.jit.ignore
和torch.jit.unused
可以用来标记那些不能完全转换为 TorchScript 的方法或函数,或者需要被编译器忽略的函数。
与Python相比
与Python中的自定义类相比,TorchScript的自定义类功能相对有限。TorchScript自定义类的特点是:
-
不支持类级别的属性。
-
除了一种情况除外,即实现接口类型或对象的子类化,否则不支持任何形式的子类化。
-
不支持方法过载。
-
必须在
__init__()
中初始化所有实例属性,因为TorchScript通过在__init__()
中推断属性类型来构建类的静态结构。 -
必须只包含符合TorchScript类型检查规则且能编译成TorchScript中间表示(IR)的方法。
示例 1
如果 Python 类使用了 @torch.jit.script
注解,它们就可以在 TorchScript 中使用,类似于声明一个 TorchScript 函数的方式:
@torch.jit.script class MyClass: def __init__(self, x: int): self.x = x def inc(self, val: int): self.x += val
示例 2
A TorchScript 自定义类类型必须通过在 __init__()
方法中的赋值来声明所有实例属性。如果一个实例属性未在 __init__()
中定义,但在其他方法中被访问,则该类无法编译为 TorchScript 类,如以下示例所示:
import torch @torch.jit.script class foo: def __init__(self): self.y = 1 # ERROR: self.x is not defined in __init__ def assign_x(self): self.x = torch.rand(2, 3)
该类将无法编译,并会显示以下错误:
RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
示例 3
在这个例子中,TorchScript 自定义类定义了一个类变量“name”,这是不允许的。
import torch @torch.jit.script class MyClass(object): name = "MyClass" def __init__(self, x: int): self.x = x def fn(a: MyClass): return a.name
它会导致以下编译时错误:
RuntimeError: '__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: File "test-class2.py", line 10 def fn(a: MyClass): return a.name ~~~~~~ <--- HERE
枚举类型
类似于自定义类,枚举类型的语义由用户定义,并且整个类定义必须能编译成TorchScript IR,并遵守TorchScript的类型检查规则。
TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" ( MemberIdentifier "=" Value )+ ( MethodDefinition )*
具体来说:
-
值必须是类型为
int
、float
或str
的 TorchScript 字面量,并且该值的类型必须与指定的 TorchScript 类型一致。 -
TSEnumType
是 TorchScript 中的一个枚举类型的名称。类似于 Python 枚举,TorchScript 允许受限的Enum
子类化,即只有在不定义任何成员的情况下才允许枚举子类化。
与Python相比
-
TorchScript 只支持
enum.Enum
。它不支持其他变体,例如enum.IntEnum
、enum.Flag
、enum.IntFlag
和enum.auto
。 -
TorchScript 枚举成员的值必须是相同类型的,并且只能是
int
、float
或str
类型,而 Python 枚举成员可以是任意类型。 -
TorchScript 忽略包含方法的枚举。
示例 1
以下示例将类 Color
定义为 Enum
类型:
import torch from enum import Enum class Color(Enum): RED = 1 GREEN = 2 def enum_fn(x: Color, y: Color) -> bool: if x == Color.RED: return True return x == y m = torch.jit.script(enum_fn) print("Eager: ", enum_fn(Color.RED, Color.GREEN)) print("TorchScript: ", m(Color.RED, Color.GREEN))
示例 2
以下示例展示了受限枚举子类化的场景:由于BaseColor
没有定义任何成员,所以它可以被Color
继承。
import torch from enum import Enum class BaseColor(Enum): def foo(self): pass class Color(BaseColor): RED = 1 GREEN = 2 def enum_fn(x: Color, y: Color) -> bool: if x == Color.RED: return True return x == y m = torch.jit.script(enum_fn) print("TorchScript: ", m(Color.RED, Color.GREEN)) print("Eager: ", enum_fn(Color.RED, Color.GREEN))
TorchScript 模块类
TSModuleType
是一种特殊类类型,它根据在 TorchScript 外部创建的对象实例进行推断。该类型的名称与对象实例的 Python 类相同。__init__()
方法不是 TorchScript 的方法,因此无需遵循 TorchScript 的类型检查规则。
模块实例类的类型模式直接从一个在TorchScript作用域之外创建的实例对象构建,而不是像自定义类那样通过__init__()
方法推断。这意味着相同类型的两个实例对象可能会遵循不同的类型模式。
从这个意义上说,TSModuleType
并不是一个真正的静态类型。因此,为了保证类型安全,TSModuleType
不能用作 TorchScript 类型注解,也不能与 TSType
进行组合。
模块实例类
TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型模式。在对 PyTorch 模块进行脚本化时,模块对象总是在外部环境中创建(即作为参数传递给 forward
方法)。Python 中的模块类被视为模块实例类,因此 Python 模块类中的 __init__()
方法不受 TorchScript 类型检查规则的约束。
TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" ClassBodyDefinition
具体来说:
-
forward()
和其他使用@torch.jit.export
装饰的方法必须能被编译成 TorchScript IR 并遵循 TorchScript 的类型检查规则。
与自定义类不同,只有模块类型中的@torch.jit.export
装饰的方法和forward
方法需要是可编译的。特别需要注意的是,__init__()
不被视为 TorchScript 方法。因此,在 TorchScript 的作用域内无法调用模块类型的构造函数。相反,TorchScript 模块对象总是在外部构建,并传递给 torch.jit.script(ModuleObj)
。
示例 1
此示例展示了模块类型的一些特点:
-
在调用
torch.jit.script
之前,TestModule
实例是在 TorchScript 的作用域之外创建的。 -
__init__()
不被视为 TorchScript 方法,因此无需进行注解,并且可以包含任意的 Python 代码。此外,在实例类中的__init__()
方法不能在 TorchScript 中调用。因为TestModule
实例是在 Python 中创建的,所以在本示例中,TestModule(2.0)
和TestModule(2)
创建了两个具有不同类型数据属性的实例。对于TestModule(2.0)
,self.x
的类型是float
,而self.y
的类型是int
。 -
TorchScript 自动编译通过
@torch.jit.export
或forward()
方法标注的方法中调用的其他方法(例如,mul()
)。 -
TorchScript程序的入口点可以是模块类型的
forward()
方法、注解为torch.jit.script
的函数,或者是注解为torch.jit.export
的方法。
import torch class TestModule(torch.nn.Module): def __init__(self, v): super().__init__() self.x = v def forward(self, inc: int): return self.x + inc m = torch.jit.script(TestModule(1)) print(f"First instance: {m(3)}") m = torch.jit.script(TestModule(torch.ones([5]))) print(f"Second instance: {m(3)}")
上述示例的输出如下:
First instance: 4 Second instance: tensor([4., 4., 4., 4., 4.])
示例 2
以下示例展示了模块类型的错误用法。具体来说,该示例在TorchScript的作用域内调用了TestModule
的构造函数。
import torch class TestModule(torch.nn.Module): def __init__(self, v): super().__init__() self.x = v def forward(self, x: int): return self.x + x class MyModel: def __init__(self, v: int): self.val = v @torch.jit.export def doSomething(self, val: int) -> int: # error: should not invoke the constructor of module type myModel = TestModule(self.val) return myModel(val) # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError # RuntimeError: Could not get name of python class object
类型注解
由于 TorchScript 是静态类型的,程序员需要在 TorchScript 代码的关键位置标注类型,确保每个局部变量或实例数据属性都有静态类型,每个函数和方法都有静态类型的签名。
何时使用类型注解
通常,类型注解仅在静态类型无法自动推断的地方需要(例如,方法或函数的参数或返回类型)。局部变量和数据属性的类型通常可以从赋值语句中自动推断出来。有时候,推断出的类型可能过于严格,例如 x
通过赋值 x = None
被推断为 NoneType
,而实际上 x
是作为 Optional
使用的。在这种情况下,需要使用类型注解来覆盖自动推断,例如 x: Optional[int] = None
。请注意,即使局部变量或数据属性的类型可以被自动推断出来,对其进行类型注解也是安全的,并且标注的类型必须与TorchScript的类型检查一致。
当一个参数、局部变量或数据属性没有类型注解,并且其类型无法自动推断时,TorchScript 会将其默认类型假设为 TensorType
、List[TensorType]
或 Dict[str, TensorType]
。
标注函数签名
因为参数可能无法从函数或方法的主体中自动推断出来,所以需要添加类型注解。如果没有注解,参数将默认为类型TensorType
。
TorchScript支持两种风格的方法和函数签名类型注解:
-
Python3风格直接在函数签名上标注类型。因此,它可以单独的参数不进行注解,默认为
TensorType
类型,也可以不指定返回类型的注解,其类型将被自动推断。
Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" FuncOrMethodBody ParamAnnot ::= Identifier [ ":" TSType ] "," ReturnAnnot ::= "->" TSType
注意,在使用Python3风格时,类型self
会自动被推断出来,无需手动注解。
-
Mypy风格 在函数或方法声明的下一行通过注释形式来标注类型。因为在 Mypy 风格中参数名称不会出现在注释里,所以所有的参数都需要被标注。
MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] ParamAnnot ::= TSType "," ReturnAnnot ::= "->" TSType
示例 1
在这个示例中:
-
a
未进行注解,默认类型为TensorType
。 -
b
被标记为int
类型。 -
返回类型无需注解,会根据实际返回值的类型自动推断为
TensorType
类型。
import torch def f(a, b: int): return a+b m = torch.jit.script(f) print("TorchScript:", m(torch.ones([6]), 100))
示例 2
以下示例采用 Mypy 风格的注解。需要注意的是,即便某些参数或返回值具有默认类型,也必须对其进行注解。
import torch def f(a, b): # type: (torch.Tensor, int) → torch.Tensor return a+b m = torch.jit.script(f) print("TorchScript:", m(torch.ones([6]), 100))
标注变量和数据属性
通常,数据属性的类型(包括类和实例数据属性)以及局部变量可以从赋值语句中自动推断出来。然而,在某些情况下,如果一个变量或属性与不同类型的数据相关联(例如 None
或 TensorType
),则可能需要显式地用如 Optional[int]
或 Any
这样的更广泛的类型进行注解。
局部变量
局部变量可以按照Python3类型的标注规则进行标注,即:
LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr
通常,局部变量的类型可以自动推断。但在某些情况下,你可能需要为那些可能与不同具体类型关联的局部变量标注多类型,例如 Optional[T]
和 Any
。
示例
import torch def f(a, setVal: bool): value: Optional[torch.Tensor] = None if setVal: value = a return value ones = torch.ones([6]) m = torch.jit.script(f) print("TorchScript:", m(ones, True), m(ones, False))
实例数据属性
对于ModuleType
类,实例数据属性可以根据Python3类型模块的注解规则进行标注。此外,这些属性(可选地)可以通过Final
标记为最终属性。
"class" ClassIdentifier "(torch.nn.Module):" InstanceAttrIdentifier ":" ["Final("] TSType [")"] ...
具体来说:
-
InstanceAttrIdentifier
是一个实例属性的名字。 -
Final
表示该属性不能在__init__
方法之外重新赋值,也不能在子类中被覆盖。
示例
import torch class MyModule(torch.nn.Module): offset_: int def __init__(self, offset): self.offset_ = offset ...
类型注解API
torch.jit.annotate(T, expr)
此 API 将类型 T
注解到表达式 expr
。这通常在表达式的默认类型不是程序员期望的类型时使用。例如,空列表(字典)的默认类型为 List[TensorType]
(Dict[TensorType, TensorType]
),但有时它可能被用来初始化其他类型的列表。另一个常见的用例是注解 tensor.tolist()
的返回类型。需要注意的是,它不能用于在__init__中注解模块属性的类型;应使用 torch.jit.Attribute
代替。
示例
在这个例子中,[]
通过 torch.jit.annotate
被声明为整数列表(而不是默认的 List[TensorType]
类型):
import torch from typing import List def g(l: List[int], val: int): l.append(val) return l def f(val: int): l = g(torch.jit.annotate(List[int], []), val) return l m = torch.jit.script(f) print("Eager:", f(3)) print("TorchScript:", m(3))
更多详细信息请参见torch.jit.annotate()
。
类型注解附录
TorchScript 类型系统的定义
TSAllType ::= TSType | TSModuleType TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType TSMetaType ::= "Any" TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | TSUnion | TSFuture | TSRRef | TSAwait TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" TSList ::= "List" "[" TSType "]" TSOptional ::= "Optional" "[" TSType "]" TSUnion ::= "Union" "[" (TSType ",")* TSType "]" TSFuture ::= "Future" "[" TSType "]" TSRRef ::= "RRef" "[" TSType "]" TSAwait ::= "Await" "[" TSType "]" TSDict ::= "Dict" "[" KeyType "," TSType "]" KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| "torch.dtype" | "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... TSTensor ::= "torch.tensor" and subclasses
不支持的类型构造
TorchScript 不支持 Python3 typing 模块的所有功能和类型。任何未在此文档中明确说明的功能都是不受支持的。typing 模块中的以下表格总结了在 TorchScript 中不被支持或受限制支持的 typing
构造。
项目 |
描述 |
|
正在开发 |
|
不予支持 |
|
不予支持 |
|
不予支持 |
|
不予支持 |
|
支持模块属性、类属性和注解,但不支持函数。 |
|
不予支持 |
|
正在开发 |
类型别名 |
不予支持 |
名义类型 |
正在开发 |
结构化类型 |
不予支持 |
新类型 |
不予支持 |
泛型 |
不予支持 |
表达式
以下部分介绍了TorchScript支持的表达式语法,该语法基于Python语言参考中的表达式章节。
算术转换
-
具有
float
或int
数据类型的Tensor可以隐式转换为FloatType或IntType的实例。前提是该张量大小为0,未将require_grad
设置为True
,并且不需要进行窄化。 -
可以从
StringType
隐式转换为DeviceType
。 -
根据上述两个要点中的隐式转换规则,可以将
TupleType
的实例转换为具有适当包含类型的ListType
实例。
显式转换可以使用内置函数float
、int
、bool
和str
来调用。这些函数接受基本数据类型作为参数,如果用户定义的类型实现了__bool__
、__str__
等方法,则也可以接受。
原子
原子是表达式的基本组成部分。
atom ::= identifier | literal | enclosure enclosure ::= parenth_form | list_display | dict_display
标识符
TorchScript 中合法标识符的规则与 Python 的规则相同,具体可以参考其 对应文档。
字面量
literal ::= stringliteral | integer | floatnumber
字面量的求值会产生具有特定值(必要时对浮点数进行近似处理)的适当类型对象。字面量是不可变的,相同的字面量多次求值可能会得到同一个对象或具有相同值的不同对象。字符串字面量、整数字面量 和 浮点数字面量 的定义与 Python 中的定义相同。
带括号的表达形式
parenth_form ::= '(' [expression_list] ')'
带圆括号的表达式列表会生成该表达式列表的结果。如果列表中至少包含一个逗号,则生成一个Tuple
; 否则,它将产生表达式列表内的单一表达式。空的圆括号对会产生一个空的Tuple
对象(Tuple[]
)。
列表和字典展示
list_comprehension ::= expression comp_for comp_for ::= 'for' target_list 'in' or_expr list_display ::= '[' [expression_list | list_comprehension] ']' dict_display ::= '{' [key_datum_list | dict_comprehension] '}' key_datum_list ::= key_datum (',' key_datum)* key_datum ::= expression ':' expression dict_comprehension ::= key_datum comp_for
列表和字典可以通过两种方式构建:一是显式列出容器内容,二是通过一组循环指令(即推导式)来计算它们。推导式的语义等同于使用for循环并向正在生成的列表中追加元素。推导式会隐式地创建自己的作用域,以确保目标列表中的项目不会泄露到封闭的作用域中。如果容器项显式列出,则表达式列表中的表达式从左到右进行评估。在dict_display
具有key_datum_list
的情况下,如果有重复的键,生成的字典将使用最右侧数据中该重复键对应的值。
主要项目
primary ::= atom | attributeref | subscription | slicing | call
属性引用
attributeref ::= primary '.' identifier
primary
必须是一个包含 identifier
属性的对象。
订阅
subscription ::= primary '[' expression_list ']'
primary
必须是一个支持订阅的对象。
-
如果主变量是
List
、Tuple
或str
,表达式列表必须是一个整数或切片。 -
如果主对象是
Dict
,则表达式列表必须评估为与Dict
的键类型相同的对象。 -
如果主模块是
ModuleList
,则表达式列表必须是整数
字面量。 -
如果主模块是
ModuleDict
,则表达式必须是stringliteral
。
切片
切片用于选择str
、Tuple
、List
或 Tensor
中的一系列项目。切片可以作为赋值语句或 del
语句中的表达式或目标。
slicing ::= primary '[' slice_list ']' slice_list ::= slice_item (',' slice_item)* [','] slice_item ::= expression | proper_slice proper_slice ::= [expression] ':' [expression] [':' [expression] ]
包含多个切片项的切片列表只能与评估为类型 Tensor
对象的主变量一起使用。
呼叫
call ::= primary '(' argument_list ')' argument_list ::= args [',' kwargs] | kwargs args ::= [arg (',' arg)*] kwargs ::= [kwarg (',' kwarg)*] kwarg ::= arg '=' expression arg ::= identifier
primary
必须是一个可调用的对象。在尝试调用之前,会先求值所有的参数表达式。
幂运算符
power ::= primary ['**' u_expr]
幂运算符具有与内置的 pow 函数相同的语义(不支持);它计算左操作数以右操作数为指数的结果。它的绑定优先级高于左侧的一元运算符,但低于右侧的一元运算符;即 -2 ** -3 == -(2 ** (-3))
。左右操作数可以是 int
, float
或 Tensor
类型。在标量与张量或张量与标量的指数运算中,会进行广播处理;而在张量与张量之间的元素级指数运算中,则不会进行任何广播。
Unary和Arithmetic Bitwise操作
u_expr ::= power | '-' power | '~' power
一元 -
运算符返回其参数的相反数。一元 ~
运算符返回其参数的按位取反。-
可以用于 int
、float
和类型为 int
或 float
的 Tensor
。而 ~
运算符只能用于 int
和类型为 int
的 Tensor
。
二进制 arithmetic 运算
由于需要符合中文习惯且保持原意,这里将 "arithmetic" 直接翻译为“算术”,因此更正如下:二进制算术运算
既然原文已经是自然通顺的表达方式,则直接返回原文即可。二进制算术运算
m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr
二元算术运算符可以操作 Tensor
、int
和 float
。对于张量与张量的操作,两个参数必须具有相同的形状。对于标量与张量或张量与标量的操作,标量通常会被广播到与张量相同大小。除法运算符只能接受标量作为其右操作数,并不支持广播。@
运算符用于矩阵乘法,仅在 Tensor
参数上运行。乘法运算符 (*
) 可以与列表和整数一起使用,以便获得重复一定次数的原始列表。
操作调整
shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr
这些操作符可以接受以下参数组合:两个 int
参数、两个 Tensor
参数,或者一个 Tensor
参数和一个 int
或者 float
标量参数。在所有情况下,右移操作定义为地板除以 pow(2, n)
,左移操作定义为乘以 pow(2, n)
。当两个参数都是 Tensor
时,它们必须具有相同的形状。如果一个参数是标量而另一个是 Tensor
,则该标量会被广播以匹配 Tensor
的大小。
二进制位运算
and_expr ::= shift_expr | and_expr '&' shift_expr xor_expr ::= and_expr | xor_expr '^' and_expr or_expr ::= xor_expr | or_expr '|' xor_expr
The &
operator calculates the bitwise AND of its arguments, the ^
calculates the bitwise XOR, and the |
calculates the bitwise OR. Both operands must be either int
or Tensor
. Alternatively, if one operand is a Tensor
and the other is an int
, then the left operand must be the Tensor
and the right operand must be the int
. When both operands are Tensors
, they need to have the same shape. If one operand is an int
and the other is a Tensor
, the int
value will be broadcast logically to match the shape of the Tensor
.
比较
comparison ::= or_expr (comp_operator or_expr)* comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'
比较会产生一个布尔值(True
或 False
),或者如果其中一个操作数是 Tensor
,则会生成一个布尔型的 Tensor
。比较可以任意地链接起来,只要它们不会生成包含多个元素的布尔型 Tensor
即可。a op1 b op2 c ...
等价于 a op1 b and b op2 c and ...
。
值比较
运算符 <
、>
、==
、>=
、<=
和 !=
用于比较两个对象的值。这两个对象通常需要是相同类型,除非它们之间存在隐式类型转换。如果用户定义的类型上实现了丰富的比较方法(例如 __lt__
),则可以对这些类型进行比较。内置类型的比较方式与 Python 类似:
-
对数字进行数学比较。
-
字符串按照字典顺序进行比较。
-
lists
、tuples
和dicts
只能与其他相同类型的lists
、tuples
和dicts
进行比较,并使用相应元素的比较运算符来进行比较。
会员测试操作
操作符 in
和 not in
用于测试成员资格。如果 x
是集合 s
的一个成员,那么 x in s
将评估为 True
,否则为 False
。表达式 x not in s
等同于 not x in s
。此操作符支持 lists
、dicts
和 tuples
,并且如果用户定义的类型实现了 __contains__
方法,则也可以与此操作符一起使用。
身份比较
对于类型int
、double
、bool
和torch.device
之外的所有类型,操作符is
和 is not
用于测试对象的身份;如果且仅如果x
和y
是同一个对象,则x is y
为True
。对于其他所有类型,is
等同于使用==
进行比较。x is not y
返回x is y
的相反值。
布尔运算
or_test ::= and_test | or_test 'or' and_test and_test ::= not_test | and_test 'and' not_test not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test
用户定义的对象可以通过实现__bool__
方法来自定义转换为bool
的方式。操作符not
在其操作数为假时返回True
,否则返回False
。表达式x
and y
首先评估x
; 如果x
为False
, 则直接返回False
; 否则,继续评估y
并返回其值 (False
或 True
)。表达式x
or y
首先评估x
; 如果x
为True
, 则直接返回True
; 否则,继续评估y
并返回其值 (False
或 True
)。
条件表达式
conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] expression ::= conditional_expression
表达式 x if c else y
首先评估条件 c
。如果 c
为 True
,则计算 x
并返回其值;否则,计算 y
并返回其值。与 if 语句类似,x
和 y
必须是相同类型的值。
表达式列表
expression_list ::= expression (',' expression)* [','] starred_item ::= '*' primary
带星号的项只能出现在赋值语句的左边,例如:a, *b, c = ...
。
简单语句
以下部分描述了 TorchScript 中支持的简单语句的语法。这部分内容是参照 Python 语言参考中的简单语句章节 制定的。
表达式语句
expression_stmt ::= starred_expression starred_expression ::= expression | (starred_item ",")* [starred_item] starred_item ::= assignment_expression | "*" or_expr
赋值语句
assignment_stmt ::= (target_list "=")+ (starred_expression) target_list ::= target ("," target)* [","] target ::= identifier | "(" [target_list] ")" | "[" [target_list] "]" | attributeref | subscription | slicing | "*" target
增强赋值语句
augmented_assignment_stmt ::= augtarget augop (expression_list) augtarget ::= identifier | attributeref | subscription augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | "**="| ">>=" | "<<=" | "&=" | "^=" | "|="
带有注解的赋值语句
annotated_assignment_stmt ::= augtarget ":" expression ["=" (starred_expression)]
raise 语句
raise_stmt ::= "raise" [expression ["from" expression]]
TorchScript 不支持使用 try\except\finally
的 Raise 语句。
assert 语句
assert_stmt ::= "assert" expression ["," expression]
TorchScript 中的 Assert 语句不支持 try\except\finally
。
return
语句
return_stmt ::= "return" [expression_list]
TorchScript 不支持 try\except\finally
语法。
del
语句
del_stmt ::= "del" target_list
pass
语句
pass_stmt ::= "pass"
print 语句
print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"
break
语句
break_stmt ::= "break"
继续语句:
continue_stmt ::= "continue"
复合语句
以下部分描述了 TorchScript 支持的复合语句的语法,并强调了 TorchScript 与常规 Python 语句之间的差异。这部分内容是基于 Python 语言参考中的复合语句章节。
if 语句
Torchscript 同时支持基本的 if/else
语句和三元 if/else
表达式。
基本的 if/else
语句
if_stmt ::= "if" assignment_expression ":" suite ("elif" assignment_expression ":" suite) ["else" ":" suite]
elif
语句可以在 else
语句之前重复任意次数。
三元 if/else
语句
if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]
示例 1
一个一维的张量会被转换为布尔类型:
import torch @torch.jit.script def fn(x: torch.Tensor): if x: # The tensor gets promoted to bool return True return False print(fn(torch.rand(1)))
上述示例的输出如下:
True
示例 2
一个多维的 tensor
不会转换为 bool
:
import torch # Multi dimensional Tensors error out. @torch.jit.script def fn(): if torch.rand(2): print("Tensor is available") if torch.rand(4,5,6): print("Tensor is available") print(fn())
运行上述代码会引发以下 RuntimeError
。
RuntimeError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last): @torch.jit.script def fn(): if torch.rand(2): ~~~~~~~~~~~~ <--- HERE print("Tensor is available") RuntimeError: Boolean value of Tensor with more than one value is ambiguous
如果一个条件变量被标注为 final
,那么会根据该条件变量的值来决定是否计算 true 分支或 false 分支。
示例 3
在这个例子中,只有 True 分支会被评估,因为变量 a
被标注为 final
并且被设置为 True
:
import torch a : torch.jit.final[Bool] = True if a: return torch.empty(2,3) else: return []
while 语句
while_stmt ::= "while" assignment_expression ":" suite
while…else 语句在 Torchscript 中不受支持,会引发 RuntimeError
错误。
for-in 语句
for_stmt ::= "for" target_list "in" expression_list ":" suite ["else" ":" suite]
for...else
语句在 Torchscript 中不受支持,会引发 RuntimeError
错误。
示例 1
对于元组的for循环:会展开成多个循环,为元组中的每个元素生成一个独立的循环体。每个循环体都必须通过类型的检验。
import torch from typing import Tuple @torch.jit.script def fn(): tup = (3, torch.ones(4)) for x in tup: print(x) fn()
上述示例的输出如下:
3 1 1 1 1 [ CPUFloatType{4} ]
示例 2
对于列表的for循环:对nn.ModuleList
进行for循环时,编译器会在编译阶段展开循环体,处理列表中的每一个元素。
class SubModule(torch.nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.randn(2)) def forward(self, input): return self.weight + input class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) def forward(self, v): for module in self.mods: v = module(v) return v model = torch.jit.script(MyModule())
with 语句
with
语句用于使用上下文管理器定义的方法来包裹并执行一段代码。
with_stmt ::= "with" with_item ("," with_item) ":" suite with_item ::= expression ["as" target]
-
如果目标包含在
with
语句中,上下文管理器的__enter__()
方法的返回值会被赋给它。与 Python 不同的是,如果异常导致代码块退出,则不会将异常类型、值和回溯信息作为参数传递给__exit__()
方法。而是提供三个None
参数。 -
try
、except
和finally
语句不能在with
块中使用。 -
在
with
块中引发的异常无法被抑制。
tuple
语句
tuple_stmt ::= tuple([iterables])
-
TorchScript 中的可迭代类型包括:
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。 -
你无法使用该内置函数将列表转换为元组。
将所有输出解包为元组的相关内容如下:
abc = func() # Function that returns a tuple a,b = func()
getattr 语句
getattr_stmt ::= getattr(object, name[, default])
-
属性名称必须是具体的字符串。
-
不支持模块类型对象(例如 torch._C)。
-
不支持自定义类对象(例如 torch.classes.*)。
hasattr 语句
hasattr_stmt ::= hasattr(object, name)
-
属性名称必须是具体的字符串。
-
不支持模块类型对象(例如 torch._C)。
-
不支持自定义类对象(例如 torch.classes.*)。
zip
语句
zip_stmt ::= zip(iterable1, iterable2)
-
参数必须是可迭代的对象。
-
支持两种具有相同外部容器类型但长度不同的可迭代对象。
示例 1
a = [1, 2] # List b = [2, 3, 4] # List zip(a, b) # works
示例 2
此示例失败的原因是迭代器的容器类型不同:
a = (1, 2) # Tuple b = [2, 3, 4] # List zip(a, b) # Runtime error
运行上述代码会引发以下 RuntimeError
。
RuntimeError: Can not iterate over a module list or tuple with a value that does not have a statically determinable length.
示例 3
支持两种容器类型相同但数据类型不同的可迭代对象:
a = [1.3, 2.4] b = [2, 3, 4] zip(a, b) # Works
TorchScript 中的可迭代类型包括:Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
enumerate 语句
enumerate_stmt ::= enumerate([iterable])
-
参数必须是可迭代的对象。
-
TorchScript 中的可迭代类型包括:
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。
Python价值观
解析规则
当给定一个 Python 值时,TorchScript 会尝试通过以下五种不同方式进行解析:
-
- 编译型Python实现:
-
-
当Python值由TorchScript能够编译的Python实现支持时,TorchScript会编译并使用该底层Python实现。
-
示例:
torch.jit.Attribute
-
-
- Python封装操作:
-
-
当一个Python值是由原生PyTorch操作符封装而成时,TorchScript会生成相应的操作符。
-
示例:
torch.jit._logging.add_stat_value
-
-
- Python 对象身份匹配:
-
-
TorchScript 对一组有限的
torch.*
API 调用(以 Python 值形式)进行支持,并尝试将这些调用中的每一个与给定的 Python 值进行匹配。 -
当匹配时,TorchScript 会生成一个对应的
SugaredValue
实例,并包含这些值的转换逻辑。 -
示例:
torch.jit.isinstance()
-
-
- 名称匹配:
-
-
对于 Python 内置的函数和常量,TorchScript 通过名称进行识别,并创建一个对应的
SugaredValue
实例来实现它们的功能。 -
示例:
all()
-
-
- 值快照:
-
-
对于来自未知模块的 Python 值,TorchScript 会尝试对其进行快照,并将其转换为函数或方法编译过程中图中的常量。
-
示例:
math.pi
-
Python 内置函数支持
内置函数 |
支持级别 |
注释 |
---|---|---|
|
部分 |
仅支持 |
|
完整 |
|
|
完整 |
|
|
无内容 |
|
|
部分 |
仅支持 |
|
部分 |
仅支持 |
|
无内容 |
|
|
无内容 |
|
|
无内容 |
|
|
无内容 |
|
|
部分 |
只支持 ASCII 字符集。 |
|
完整 |
|
|
无内容 |
|
|
无内容 |
|
|
无内容 |
|
|
完整 |
|
|
无内容 |
|
|
完整 |
|
|
完整 |
|
|
无内容 |
|
|
无内容 |
|
|
无内容 |
|
|
部分 |
不支持 |
|
部分 |
不支持手动索引指定。| 不支持格式类型修饰符。 |
|
无内容 |
|
|
部分 |
属性名必须是字符串字面量。 |
|
无内容 |
|
|
部分 |
属性名必须是字符串字面量。 |
|
完整 |
|
|
部分 |
仅支持 |
|
完整 |
仅支持 |
|
无内容 |
|
|
部分 |
|
|
完整 |
|
|
无内容 |
|
|
无内容 |
|
|
完整 |
|
|
完整 |
|
|
部分 |
只支持 ASCII 字符集。 |
|
完整 |
|
|
部分 |
|
|
无内容 |
|
|
完整 |
|
|
无内容 |
|
|
无内容 |
|
|
部分 |
|
|
无内容 |
|
|
无内容 |
|
|
完整 |
|
|
部分 |
|
|
完整 |
|
|
部分 |
|
|
完整 |
|
|
部分 |
它只能在 |
|
无内容 |
|
|
无内容 |
|
|
完整 |
|
|
无内容 |
Python 内置值支持
内置值 |
支持级别 |
注释 |
---|---|---|
|
完整 |
|
|
完整 |
|
|
完整 |
|
|
无内容 |
|
|
完整 |
torch.* APIs
远程过程调用
TorchScript 支持一组 RPC API,允许在指定的远程工作者上而非本地运行函数。
具体来说,以下 API 获得全面支持:
-
-
torch.distributed.rpc.rpc_sync()
-
-
rpc_sync()
会阻塞并进行远程过程调用,在远程工作者上运行一个函数。RPC消息的发送和接收与Python代码的执行是并行进行的。 -
有关其用法和示例的更多信息,可以参考
rpc_sync()
。
-
-
-
-
torch.distributed.rpc.rpc_async()
-
-
rpc_async()
发起一个非阻塞的 RPC 调用,用于在远程工作者上运行函数。RPC 消息会与 Python 代码的执行并行地发送和接收。 -
有关其用法和示例的更多信息,可以参考
rpc_async()
。
-
-
-
-
torch.distributed.rpc.remote()
-
-
remote()
在一个工作者上执行远程调用,并返回一个RRef
远程引用。 -
有关其用法和示例的更多信息,可以参考
remote()
。
-
-
异步执行
TorchScript 允许你创建异步计算任务,从而更有效地利用计算资源。这通过提供一组仅在 TorchScript 中使用的 API 来实现。
类型注解
TorchScript 是静态类型的。它提供了一系列工具来帮助标注变量和属性。
-
-
torch.jit.annotate()
-
-
为 TorchScript 提供类型提示,在 Python 3 风格的类型提示不适用的情况下使用。
-
一个常见的例子是对类似
[]
的表达式添加类型注解。默认情况下,[]
被视为List[torch.Tensor]
。当需要不同类型的列表时,可以使用以下代码来提示TorchScript:torch.jit.annotate(List[int], [])
。 -
更多详情请参见
annotate()
-
-
-
-
torch.jit.Attribute
-
-
常见的用例是为
torch.nn.Module
属性提供类型提示。由于它们的__init__
方法不会被TorchScript解析,因此在模块的__init__
方法中应使用torch.jit.Attribute
而不是torch.jit.annotate
。 -
更多详情请参见
Attribute()
-
-
-
-
torch.jit.Final
-
-
Python的
typing.Final
的一个别名。为了保持向后兼容性,保留了torch.jit.Final
。
-
-
元编程
TorchScript 提供了一系列工具,以方便进行元编程。
-
-
torch.jit.is_scripting()
-
-
返回一个布尔值,指示当前程序是否通过
torch.jit.script
进行了编译。 -
当在
assert
或if
语句中使用时,如果torch.jit.is_scripting()
评估结果为False
,则该作用域或分支不会被编译。 -
它的值可以在编译时进行静态评估,因此通常用于
if
语句中,以防止TorchScript编译某个分支。 -
更多细节和示例请参见
is_scripting()
-
-
-
-
torch.jit.is_tracing()
-
-
返回一个布尔值,表示当前程序是否被
torch.jit.trace
或torch.jit.trace_module
追踪。 -
更多细节可以参考
is_tracing()
-
-
-
-
@torch.jit.ignore
-
-
这个装饰器告诉编译器忽略某个函数或方法,并将其保留在 Python 函数形式。
-
这允许你在模型中保留暂时不兼容TorchScript的代码。
-
当一个由
@torch.jit.ignore
装饰的函数在 TorchScript 中被调用时,该函数会将其调用转发给 Python 解释器。 -
包含未被考虑的函数的模型无法导出。
-
更多细节和示例可以在
ignore()
中查看
-
-
类型精化
-
-
torch.jit.isinstance()
-
-
返回一个布尔值,指示变量是否为指定类型。
-
有关其用法和示例的更多信息,可以参考
isinstance()
。
-
-