torch.jit.fork

torch.jit.fork(func, *args, **kwargs)[源代码]

创建一个异步任务来执行 func,并引用该任务执行结果的值。

fork 将立即返回,因此func的返回值可能还未计算完成。要强制完成任务并获取返回值,请在 Future 上调用 torch.jit.waitfork 调用带有返回类型为 Tfunc 时,其类型定义为 torch.jit.Future[T]fork 调用可以任意嵌套,并且可以用位置参数和关键字参数调用。异步执行仅在 TorchScript 中运行时发生;如果在纯 Python 中运行,则fork不会并行执行。fork 在记录跟踪时被调用也不会并行执行,但 forkwait 调用将被捕获到导出的 IR 图中。

警告

fork 任务将以非确定性方式执行。我们建议只为那些不修改输入、模块属性或全局状态的纯函数创建并行 fork 任务。

参数
  • func (callabletorch.nn.Module) – 一个将被调用的 Python 函数或 torch.nn.Module。如果在 TorchScript 中执行,它将以异步方式运行;否则不会。fork 的追踪调用将在 IR 中被捕获。

  • *args - 用于调用func的参数。

  • **kwargs - 调用func时使用的参数。

返回值

func 的执行引用。值 T 只能通过使用 torch.jit.wait 强制完成 func 来获取。

返回类型

torch.jit.Future[T]

示例(创建一个自由函数的分支):

import torch
from torch import Tensor
def foo(a : Tensor, b : int) -> Tensor:
    return a + b
def bar(a):
    fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

示例( fork 一个模块方法):

import torch
from torch import Tensor
class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b : int):
        return a + b
class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super(self).__init__()
        self.mod = AddMod()
    def forward(self, input):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)
input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)
本页目录