torch.jit.annotate

torch.jit.annotate(the_type, the_value)[源代码]

用于在TorchScript编译器中为the_value指定类型。

此方法是一个传递函数,用于返回the_value,并告知TorchScript编译器the_value的类型。在非TorchScript环境下运行时,该方法不会执行任何操作。

虽然TorchScript可以为大多数Python表达式正确推断类型,但在某些情况下可能会出现错误,包括:

  • 空的容器如[]{},TorchScript 认为它们是存储Tensor类型的容器

  • 类似于Optional[T]的可选类型,如果被赋予了类型为T的有效值,TorchScript 将会认为它的类型是T而不是Optional[T]

注意,annotate()torch.nn.Module 子类的 __init__ 方法中不起作用,因为它是以即时执行模式运行的。要为 torch.nn.Module 属性注解类型,请使用Attribute()

示例:

import torch
from typing import Dict

@torch.jit.script
def fn():
    # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
    # instead of default dictionary type of (str -> Tensor).
    d = torch.jit.annotate(Dict[str, int], {})

    # Without `torch.jit.annotate` above, following statement would fail because of
    # type mismatch.
    d["name"] = 20
参数
  • the_type – 用于 the_value 的 Python 类型,应作为类型提示传递给 TorchScript 编译器

  • the_value - 用于提示类型的值或表达式。

返回值

the_value 作为返回值传递回去。

本页目录