torch.jit.annotate
- torch.jit.annotate(the_type, the_value)[源代码]
-
用于在TorchScript编译器中为the_value指定类型。
此方法是一个传递函数,用于返回the_value,并告知TorchScript编译器the_value的类型。在非TorchScript环境下运行时,该方法不会执行任何操作。
虽然TorchScript可以为大多数Python表达式正确推断类型,但在某些情况下可能会出现错误,包括:
-
空的容器如[]和{},TorchScript 认为它们是存储Tensor类型的容器
-
类似于Optional[T]的可选类型,如果被赋予了类型为T的有效值,TorchScript 将会认为它的类型是T而不是Optional[T]
注意,annotate() 在 torch.nn.Module 子类的 __init__ 方法中不起作用,因为它是以即时执行模式运行的。要为 torch.nn.Module 属性注解类型,请使用
Attribute()
。示例:
import torch from typing import Dict @torch.jit.script def fn(): # Telling TorchScript that this empty dictionary is a (str -> int) dictionary # instead of default dictionary type of (str -> Tensor). d = torch.jit.annotate(Dict[str, int], {}) # Without `torch.jit.annotate` above, following statement would fail because of # type mismatch. d["name"] = 20
- 参数
-
-
the_type – 用于 the_value 的 Python 类型,应作为类型提示传递给 TorchScript 编译器
-
the_value - 用于提示类型的值或表达式。
-
- 返回值
-
the_value 作为返回值传递回去。
-