torch.cuda.make_graphed_callables

torch.cuda.make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None)[源代码]

接受可调用对象(如函数或nn.Module)并返回其图形表示版本。

每个绘制的可调用对象的前向传递会在一个自动微分节点内,将其源可调用对象的前向 CUDA 操作以 CUDA 图的形式运行。

可调用对象的正向传递还会将一个反向节点添加到自动微分图中。在反向传递过程中,该节点会作为 CUDA 图来执行可调用对象的反向操作。

因此,每个绘制的可调用函数都应在支持自动微分的训练循环中替换其原始可调用函数。

请参阅部分网络捕获以了解详细用法和限制条件。

如果你传递一个包含多个可调用对象的元组,它们的捕获会共享同一个内存池。关于这种做法何时适用,请参见图内存管理

参数
  • callables (torch.nn.ModulePython 函数,或 元组 包含这些项中的一个或多个) – 可调用对象或可调用对象的集合。有关何时传递包含可调用对象的元组的信息,请参阅 图内存管理。如果你传递了一个包含可调用对象的元组,那么这些可调用对象在元组中的顺序必须与它们在实际工作负载中运行时的顺序相同。

  • sample_args (元组 of Tensors, 或 元组 of 元组 of Tensors) – 每个可调用对象的样本参数。如果传递了一个单一的可调用对象,sample_args 应该是一个包含参数张量的单个元组。如果传递了一个可调用对象的元组,sample_args 应该是包含参数张量的元组组成的元组。

  • num_warmup_iters (int) – 热身迭代的次数。目前,DataDistributedParallel 需要 11 次迭代来进行热身,默认值为 3

  • allow_unused_input (bool) – 如果设置为 False,指定在计算输出时未使用的输入(因此它们的梯度始终为零)将被视为错误。默认值为 False。

  • pool (可选) – 一个令牌(由 graph_pool_handle()other_Graph_instance.pool() 返回),用于提示此图可能与指定的内存池共享内存。详情请参阅图内存管理

注意

每个张量的 requires_grad 状态必须与其在 sample_args 中对应的实际输入在训练循环中的预期状态相匹配。

警告

此 API 处于 beta 阶段,未来版本可能有所更改。

警告

sample_args 对于每个可调用对象,必须仅包含张量。不允许其他类型的对象。

警告

返回的可调用对象不支持高阶微分(例如,双重反向微分)。

警告

在传递给make_graphed_callables()的任何Module中,只有参数可以是可训练的。缓冲区必须设置requires_grad=False

警告

在将一个 torch.nn.Module 通过 make_graphed_callables() 处理之后,你不能添加或移除该 Module 的任何参数或缓冲区。

警告

torch.nn.Module 传递给 make_graphed_callables() 时,不能在这些模块上注册挂钩。但是,在通过 make_graphed_callables() 传递之后再对它们进行挂钩注册是允许的。

警告

当运行一个图化的可调用对象时,必须按该可调用对象的sample_args中指定的顺序和格式传递参数。

警告

自动混合精度仅在禁用缓存的情况下受支持于make_graphed_callables()。使用上下文管理器torch.cuda.amp.autocast()时,必须设置cache_enabled=False

本页目录