torch.jit.script
- torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)[源代码]
-
编写函数脚本。
将一个函数或
nn.Module
脚本化会检查其源代码,使用TorchScript编译器将其转换为TorchScript代码,并返回一个ScriptModule
或ScriptFunction
。TorchScript是Python语言的一个子集,因此并非所有Python功能都可用,但我们提供了足够的功能来处理张量和进行控制依赖操作。完整的指南请参阅TorchScript 语言参考。将字典或列表脚本化会将其数据复制到一个TorchScript实例中。这个实例可以在Python和TorchScript之间通过引用来回传递,无需额外的内存拷贝。
-
torch.jit.script
可以用于模块、函数、字典和列表 -
以及作为装饰器
@torch.jit.script
用于 TorchScript 类和函数。
- 参数
-
-
obj (Callable, class, 或 nn.Module) – 需要编译的对象,可以是
nn.Module
、函数、类类型、字典或列表。 -
example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供示例输入以注释函数或
nn.Module
的参数。
-
- 返回值
-
如果
obj
是nn.Module
,script
将返回一个ScriptModule
对象。该对象将具有与原始nn.Module
相同的子模块和参数集合。如果obj
是一个独立函数,则返回一个ScriptFunction
对象。如果obj
是一个dict
,则script
返回torch._C.ScriptDict 的实例。如果obj
是一个list
,则script
返回torch._C.ScriptList 的实例。
- 编写函数脚本
-
使用
@torch.jit.script
装饰器可以通过编译函数体来生成一个ScriptFunction
。示例(将函数编写为脚本):
import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r print(type(foo)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(foo.code) # Call the function using the TorchScript interpreter foo(torch.ones(2, 2), torch.ones(2, 2))
- 使用 example_inputs 编写函数脚本
-
示例输入可以用来标记函数参数。
示例(在编写脚本之前标注函数):
import torch def test_sum(a, b): return a + b # Annotate the arguments to be int scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) print(type(scripted_fn)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(scripted_fn.code) # Call the function using the TorchScript interpreter scripted_fn(20, 100)
- 创建nn.Module脚本
-
默认情况下,通过脚本编写一个
nn.Module
会编译其forward
方法,并递归地编译由forward
调用的所有方法、子模块和函数。如果一个nn.Module
仅使用了TorchScript支持的功能,则无需对原始代码进行任何修改。script
将构建ScriptModule
,该模块包含原始模块的所有属性、参数和方法的副本。示例(使用包含参数的简单模块进行脚本编写):
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super().__init__() # This parameter will be copied to the new ScriptModule self.weight = torch.nn.Parameter(torch.rand(N, M)) # When this submodule is used, it will be compiled self.linear = torch.nn.Linear(N, M) def forward(self, input): output = self.weight.mv(input) # This calls the `forward` method of the `nn.Linear` module, which will # cause the `self.linear` submodule to be compiled to a `ScriptModule` here output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3))
示例(带有追踪子模块的脚本化模块):
import torch import torch.nn as nn import torch.nn.functional as F class MyModule(nn.Module): def __init__(self) -> None: super().__init__() # torch.jit.trace produces a ScriptModule's conv1 and conv2 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input scripted_module = torch.jit.script(MyModule())
要编译除
forward
之外的方法及其调用的其他方法,请在该方法上添加@torch.jit.export
装饰器。若不想进行编译,则可以使用@torch.jit.ignore
或@torch.jit.unused
。示例(模块中导出但被忽略的方法):
import torch import torch.nn as nn class MyModule(nn.Module): def __init__(self) -> None: super().__init__() @torch.jit.export def some_entry_point(self, input): return input + 10 @torch.jit.ignore def python_only_fn(self, input): # This function won't be compiled, so any # Python APIs can be used import pdb pdb.set_trace() def forward(self, input): if self.training: self.python_only_fn(input) return input * 99 scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2)))
示例(使用 example_inputs 对 nn.Module 的前向方法进行注解):
import torch import torch.nn as nn from typing import NamedTuple class MyModule(NamedTuple): result: List[int] class TestNNModule(torch.nn.Module): def forward(self, a) -> MyModule: result = MyModule(result=a) return result pdt_model = TestNNModule() # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) # Run the scripted_model with actual inputs print(scripted_model([20]))
-