torch.utils.checkpoint
注意
检查点通过在反向传播过程中重新运行每个已保存段的前向传递来实现。这可能导致持久状态(如随机数生成器 (RNG) 状态)比没有启用检查点时更先进。默认情况下,检查点包含逻辑以调整 RNG 状态,使得使用 RNG 的检查点传递(例如通过 dropout 使用)与非检查点传递相比具有确定性输出。根据已保存操作的运行时间,存储和恢复 RNG 状态可能会带来一定的性能损失。如果不需要与非检查点传递相比具有确定性输出,则可以通过向 checkpoint
或 checkpoint_sequential
传递参数 preserve_rng_state=False
来省略在每个检查点期间存储和恢复 RNG 状态。
暂存逻辑会保存和恢复CPU及其他设备类型的随机数生成器状态(通过_infer_device_type
从张量参数中推断设备类型,不包括CPU张量)到run_fn
。如果有多个设备,则只会为单一类型的设备保存设备状态,其余设备将被忽略。因此,如果检查点函数涉及随机性,可能会导致不正确的梯度。(注意,如果检测到CUDA设备,则会优先选择;否则,会选择遇到的第一个设备。)如果没有CPU张量,默认设备类型的状态(默认值是cuda,并且可以通过DefaultDeviceType
设置为其他设备)将被保存和恢复。然而,逻辑无法预测用户是否会在run_fn
内部将张量移动到新设备上。“新”设备指的是不属于[当前设备 + 张量参数的设备]集合中的设备。因此,在run_fn
中如果将张量移动到新设备,与非检查点传递相比,确定性输出永远不会得到保证。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[源代码]
-
为模型或模型的一部分创建检查点。
激活检查点是一种通过牺牲计算资源来节省内存的技术。在反向传播过程中,不需要将用于梯度计算的中间结果(张量)一直保存到实际使用时,而是省略了这些中间结果的存储,并在需要时重新计算它们。这种技术可以应用于模型中的任何部分。
目前有两种检查点实现可供选择,取决于
use_reentrant
参数的设置。建议使用use_reentrant=False
。请参阅下方的注释以了解它们之间的差异。警告
如果在反向传递过程中
function
的调用与正向传递不同(例如由于全局变量),检查点版本可能不再等价,从而可能导致错误被触发或导致静默的梯度计算问题。警告
应显式传递
use_reentrant
参数。在版本 2.4 中,如果没有传递此参数,则会抛出异常。如果你使用的是use_reentrant=True
,请参阅下方的重要注意事项和潜在限制。注意
当
use_reentrant=True
时,checkpoint 是可重入的;当use_reentrant=False
时,则是非可重入的。这两种变体在以下几个方面存在差异:-
非再入检查点会在所有必需的中间激活被重新计算完成后立即停止进一步的重新计算。此功能默认开启,但可以通过调用
set_checkpoint_early_stop()
来禁用。而再入检查点则会在反向传递期间完整地重新计算整个function
。 -
重新进入变体在前向传递过程中不会记录自动微分图,因为它是在
torch.no_grad()
下运行的。而非重新进入版本会记录自动微分图,在检查点区域内部可以进行反向传播。 -
重新进入检查点仅支持
torch.autograd.backward()
API 的反向传播(不带inputs参数),而非重新进入版本则支持所有方式进行反向传递。 -
至少需要一个输入或输出设置
requires_grad=True
,才能使用重新进入变体。如果不满足此条件,模型中的检查点部分将无法计算梯度。非重新进入版本则没有这一限制。 -
重新进入版本不会将嵌套结构中的张量(如自定义对象、列表、字典等)纳入自动微分的范围,而非重新进入版本则会将其纳入。
-
重新进入检查点不支持包含与计算图分离的张量的检查点区域,而非重新进入版本则支持。对于重新进入变体,如果检查点段中包含使用
detach()
或torch.no_grad()
分离的张量,则反向传播会引发错误。这是因为checkpoint
使所有输出都需要梯度,这会导致模型中定义某些张量没有梯度时出现问题。为了避免这种情况,请在调用checkpoint
函数之前分离这些张量。
- 参数
-
-
函数 – 描述在模型或其部分的前向传递中要执行的操作。它还应知道如何处理以元组形式传入的参数。例如,在 LSTM 中,如果用户传递
(activation, hidden)
,function
应该正确地将第一个输入识别为activation
,第二个输入识别为hidden
-
preserve_rng_state (bool, 可选) – 在每个检查点期间不保存和恢复随机数生成器的状态。需要注意的是,在使用 torch.compile 时,此标志无效,并且系统会始终保留随机数生成器状态。默认值:
True
-
use_reentrant (bool) – 指定是否使用需要重新进入自动微分的激活检查点变体。此参数应显式传递,否则在版本 2.5 中会抛出异常。如果
use_reentrant=False
,checkpoint
将使用不需要重新进入自动微分的实现,并支持更多功能,例如与torch.autograd.grad
正常配合工作以及将关键字参数传递给检查点函数。 -
context_fn (Callable, 可选) – 返回两个上下文管理器元组的可调用对象。该函数及其重新计算将分别在第一个和第二个上下文管理器下运行。此参数仅在
use_reentrant=False
时受支持。 -
determinism_check (str, 可选) – 指定要执行的确切性检查的字符串。默认值为
"default"
,它会比较重新计算张量的形状、数据类型和设备与保存的张量是否一致。若要关闭此检查,请指定"none"
。目前仅支持这两种值。如果您希望看到更多确切性检查,请打开一个问题。如果use_reentrant=False
,则支持该参数;如果use_reentrant=True
,则始终禁用确切性检查。 -
debug (bool, 可选) – 如果为
True
,错误消息将包括在原始前向计算期间运行的操作符跟踪以及重新计算的跟踪。此参数仅在use_reentrant=False
时受支持。 -
args – 包含
函数
输入的元组
-
- 返回值
-
代码
function
在参数*args
上的运行结果
-
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[源代码]
-
为序列模型设置检查点以节省内存。
顺序模型按顺序执行模块/函数列表。因此,我们可以将此类模型划分为多个段,并为每一段创建检查点。除了最后一个段之外,其他所有段都不会存储中间激活值。每个已检查点的段的输入将在反向传递中保存以重新运行该段。
警告
应显式传递
use_reentrant
参数。在版本 2.4 中,如果未传递该参数,则会抛出异常。如果你使用的是use_reentrant=True
变体,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 以了解此变体的重要注意事项和限制。建议使用use_reentrant=False
。- 参数
-
-
functions – 一个
torch.nn.Sequential
或者是组成模型的模块和函数的列表(按顺序运行)。 -
segments - 模型中创建的块的数量
-
input – 传递给
functions
的输入张量 -
preserve_rng_state (bool, optional) – 在每个检查点期间不保存和恢复随机数生成器的状态。默认值:
True
-
use_reentrant (bool) – 指定是否使用需要重新进入自动微分的激活检查点变体。此参数应显式传递,否则在版本 2.5 中会抛出异常。如果
use_reentrant=False
,checkpoint
将使用不需要重新进入自动微分的实现,并支持更多功能,例如与torch.autograd.grad
正常配合工作以及将关键字参数传递给检查点函数。
-
- 返回值
-
在
*inputs
上顺序运行functions
的输出结果
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[源代码]
-
这是一个上下文管理器,用于控制在运行时检查点是否应打印额外的调试信息。有关
checkpoint()
中debug
参数的更多信息,请参阅相关文档。需要注意的是,在设置此上下文管理器时,它会覆盖传递给checkpoint()
的debug
值。如果希望使用本地设置,则可以将None
传递给该上下文。- 参数
-
enabled (bool) – 是否开启调试信息的打印。默认值为 ‘None’。
- 类torch.utils.checkpoint.CheckpointPolicy(value)[源代码]
-
用于指定反向传播过程中检查点策略的枚举。
以下策略受到支持:
-
{MUST,PREFER}_SAVE
: 操作的输出将在正向传播时保存,在反向传播时不再重新计算。 -
{MUST,PREFER}_RECOMPUTE
: 在前向传递中不保存操作的输出,在反向传递时会重新计算。
使用
MUST_*
而不是PREFER_*
来表明该策略不应被其他子系统(如torch.compile)覆盖。注意
一个总是返回
PREFER_RECOMPUTE
的策略函数等同于常规的检查点功能。一个每次操作都返回
PREFER_SAVE
的策略函数,并不代表不使用检查点。使用这种策略会保存超出实际用于梯度计算所需之外的额外张量。 -
- 类torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[源代码]
-
在选择性检查点过程中传递给策略函数的上下文信息。
此类用于在选择性检查点过程中将相关元数据传递给策略函数。元数据包括当前调用策略函数时是否处于重新计算阶段。
示例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[源代码]
-
帮助在激活检查点期间避免重新计算某些操作。
使用torch.utils.checkpoint.checkpoint来控制反向传递期间重新计算的操作。
- 参数
-
-
policy_fn_or_list (Callable 或 List) –
-
如果提供了策略函数,它应接受一个
SelectiveCheckpointContext
、操作的OpOverload
以及传递给操作的args
和kwargs
,并返回一个CheckpointPolicy
枚举值,指示是否应重新计算该操作的执行。 -
如果提供了操作列表,则相当于该策略对指定的操作返回CheckpointPolicy.MUST_SAVE,而对所有其他操作返回CheckpointPolicy.PREFER_RECOMPUTE。
-
-
allow_cache_entry_mutation (bool, optional) – 默认情况下,如果由选择性激活检查点缓存的任何张量被修改,则会引发错误以确保正确性。若设置为True,则禁用此检查。
-
- 返回值
-
由两个上下文管理器组成的元组。
示例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )