Fully Sharded Data Parallel

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[源代码]

用于在数据并行工作者之间分配模块参数的封装器。

这受到了Xu等人的工作以及DeepSpeed的ZeRO Stage 3的启发。FullyShardedDataParallel通常简写为FSDP。

要了解FSDP的内部机制,请参阅FSDP Notes

示例:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

使用FSDP需要先包裹你的模块,然后再初始化优化器。这是因为FSDP会改变参数变量。

在设置FSDP时,需要考虑目标CUDA设备。如果该设备具有ID(dev_id),你有三种选择:

  • 将模块放在该设备上

  • 使用torch.cuda.set_device(dev_id)设置设备

  • dev_id作为device_id构造函数的参数传递。

这确保了FSDP实例的计算设备是目标设备。对于选项1和3,FSDP初始化始终在GPU上进行。而对于选项2,FSDP初始化则会在模块所在设备上执行,该设备可能是一块CPU。

如果你使用了 sync_module_states=True 标志,你需要确保模块运行在 GPU 上,或者使用 device_id 参数指定一个 CUDA 设备,FSDP 将在这个设备上移动模块。这是必要的,因为 sync_module_states=True 需要通过 GPU 进行通信。

FSDP 负责将输入张量从前向方法移动到 GPU 计算设备,因此你无需手动从 CPU 进行移动。

use_orig_params=True 时,ShardingStrategy.SHARD_GRAD_OP 暴露的是未分片的原始参数,而不是像 ShardingStrategy.FULL_SHARD 在前向传播后暴露的分片参数。如果你想检查梯度,可以使用带有 with_grads=Truesummon_full_params 方法。

当设置 limit_all_gathers=True 时,你可能会在 FSDP 预处理阶段看到 CPU 线程暂停发出内核的情况。这是有意为之,展示了速率限制器的效果。通过这种方式同步 CPU 线程可以防止为后续的所有收集操作过度分配内存,并且实际上不会延迟 GPU 内核的执行。

FSDP 在正向和反向计算期间,会用 torch.Tensor 视图替换管理模块的参数(出于自动微分相关的原因)。如果你的模块在前向传递中依赖于保存的参数引用而不是每次迭代都重新获取引用,则它将无法看到 FSDP 新创建的视图,从而导致自动微分无法正确工作。

最后,在使用 sharding_strategy=ShardingStrategy.HYBRID_SHARD 时,如果分片进程组在节点内而复制进程组在节点间,则对于某些集群配置,设置 NCCL_CROSS_NIC=1 可以帮助改善复制进程组中的 all-reduce 时间。

限制

在使用FSDP时,需要留意一些限制条件:

  • 当前,FSDP 在使用 CPU 卸载时不支持在 no_sync() 之外进行梯度累积。这是因为 FSDP 使用新计算的梯度而不是与现有梯度累加,这可能会导致错误的结果。

  • FSDP 不支持运行包含在其实例中的子模块的前向传递。这是因为在 FSDP 实例中,子模块的参数会被分片,而子模块本身不是 FSDP 实例,所以它的前向传递无法正确地聚集完整的参数。

  • 由于 FSDP 注册反向钩的方式,它不能与双倍反向传播一起使用。

  • FSDP 在冻结参数时有一些限制。当 use_orig_params=False 时,每个 FSDP 实例必须管理全部被冻结或全部未被冻结的参数。而当 use_orig_params=True 时,FSDP 支持混合使用冻结和未冻结的参数,但建议避免这样做以防止梯度内存使用超出预期。

  • 截至 PyTorch 1.12,FSDP 对共享参数提供了有限的支持。如果你的用例需要增强的共享参数支持,请在此问题链接中提出。

  • 你应该避免在前向传播和反向传播之间直接修改参数(不使用summon_full_params上下文),因为这样修改的参数可能不会被保存。

参数
  • module (nn.Module) – 需要使用 FSDP 包装的模块。

  • process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是模型被分片的进程组,也是 FSDP 集体通信(如 all-gather 和 reduce-scatter)所使用的进程组。如果为 None,则使用默认的进程组。对于混合分片策略(例如 ShardingStrategy.HYBRID_SHARD),用户可以传入一个包含两个进程组的元组,分别表示用于分片和复制的进程组。如果未指定,则 FSDP 会为用户构建用于节点内分片和节点间复制的进程组。(默认值: None

  • sharding_strategy (Optional[ShardingStrategy]) – 此配置用于设置分片策略,该策略可能在内存节省和通信开销之间进行权衡。有关详细信息,请参见ShardingStrategy。(默认值: FULL_SHARD)

  • cpu_offload (可选[CPUOffload]) – 用于配置CPU卸载。如果设置为None,则不进行任何CPU卸载操作。详情请参见CPUOffload。(默认值: None

  • auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int]], bool], ModuleWrapPolicy, CustomPolicy]) – 自动包装策略,可选的联合类型,包括一个回调函数和布尔值、模块包装策略或自定义策略。

    这定义了一种策略,用于将 FSDP 应用于 module 的子模块,以实现通信和计算的重叠,并影响性能。如果设置为None,则 FSDP 仅应用于 module 自身,用户需要手动将 FSDP 应用于父模块(从下至上进行)。为了方便起见,这直接接受 ModuleWrapPolicy,允许用户指定要包装的模块类(例如变换器块)。否则,它应该是一个可调用对象,该对象接收三个参数:module: nn.Modulerecurse: boolnonwrapped_numel: int,并返回一个布尔值,指定是否将 FSDP 应用于传入的模块(如果 recurse=False)或继续遍历模块的子树(如果recurse=True)。用户可以向可调用对象添加额外参数。例如,在torch.distributed.fsdp.wrap.py中的 size_based_auto_wrap_policy 提供了一个示例,该策略会在模块的子树中参数超过 100M numel 时应用 FSDP。我们建议在应用 FSDP 后打印模型并根据需要进行调整。

    示例:

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
    
  • backward_prefetch (Optional[BackwardPrefetch]) – 此配置用于显式地预取 all-gathers 的反向传播。如果为 None,则 FSDP 不会进行反向预取,在反向传递中没有通信和计算的重叠。有关详细信息,请参见BackwardPrefetch。(默认值:BACKWARD_PRE

  • mixed_precision (Optional[MixedPrecision]) – 此配置为 FSDP 设置原生混合精度。如果设置为 None,则不使用混合精度。否则,可以分别设置参数、缓冲区和梯度的减少数据类型。详情请参阅 MixedPrecision。(默认值: None)

  • ignored_modules (Optional[Iterable[torch.nn.Module]]): 被此实例忽略的模块,包括这些模块自身的参数、子模块的参数和缓冲区。在 ignored_modules 中直接包含的任何模块都不应是FullyShardedDataParallel 实例,如果这些实例已经构建为 FullyShardedDataParallel,则它们不会被忽略。当使用auto_wrap_policy 或参数的分片不由 FSDP 管理时,可以使用此参数避免在模块级别对特定参数进行分片。(默认值: None)

  • param_init_fn (Optional[Callable[ [nn.Module] , None ] ]) –

    A Callable[torch.nn.Module] -> None 用于指定当前位于元设备上的模块应该如何初始化到实际设备上。从 v1.12 版本开始,FSDP 使用 is_meta 检测具有参数或缓冲区的元设备模块,并在指定了的情况下应用 param_init_fn 或者调用 nn.Module.reset_parameters()。在这两种情况下,实现代码应仅初始化该模块的参数/缓冲区,而不包括其子模块的参数/缓冲区,以避免重复初始化。此外,FSDP 还支持通过 torchdistX (https://github.com/pytorch/torchdistX) 的 deferred_init() API 实现延迟初始化,在这种情况下,延迟模块将通过调用 param_init_fn 或者 torchdistX 默认的 materialize_module() 进行初始化。如果指定了 param_init_fn,则它会被应用于所有元设备模块,这意味着该函数可能需要根据模块类型进行处理。FSDP 在参数展平和分片之前调用初始化函数。

    示例:

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module: nn.Module):
    >>>     # E.g. initialize depending on the module type
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    
  • device_id (可选[Union[int, torch.device]]) – 指定 FSDP 初始化所使用的 CUDA 设备,包括模块初始化和参数分片。如果 module 在 CPU 上,则应指定此参数以提高初始化速度。默认情况下,若已设置默认的 CUDA 设备(例如通过 torch.cuda.set_device),可以传递 torch.cuda.current_device。(默认值: None

  • sync_module_states (bool) – 如果为True,则每个 FSDP 模块将从 rank 0 广播模块参数和缓冲区,确保它们在所有 rank 上复制(增加此构造函数的通信开销)。这可以通过 load_state_dict 内存高效地加载 state_dict 检查点。有关示例,请参见FullStateDictConfig。(默认值:False)

  • forward_prefetch (bool) – 如果为 True,则 FSDP 在当前前向计算之前预取下一个前向传递的 all-gather。这仅对 CPU 绑定的工作负载有用,在这种情况下提前发出下一次 all-gather 可能会提高并行度。此选项仅适用于静态图模型,因为预取顺序遵循第一次迭代的执行顺序。(默认值: False)

  • limit_all_gathers (bool) – 如果设置为 True,FSDP 会显式同步 CPU 线程,确保 GPU 内存使用仅限于连续的两个 FSDP 实例(当前执行计算的实例和下一个预取 all-gather 的实例)。如果设置为 False,则 FSDP 允许 CPU 线程在没有额外同步的情况下发出 all-gathers。(默认值:True)我们通常称此功能为“速率限制器”。仅当特定工作负载以 CPU 为中心且内存压力较低时才应将其设置为 False,在这种情况下,CPU 线程可以积极地发出所有内核而不必担心 GPU 内存使用。

  • use_orig_params (bool) – 设置为True可以让FSDP使用模块的原始参数。FSDP通过nn.Module.named_parameters()向用户暴露这些原始参数,而不是内部的FlatParameter对象。这意味着优化器会在原始参数上运行,支持每个原始参数独立设置超参数。FSDP保留了原始参数变量,并在未分片和已分片形式之间操作其数据,它们始终是底层未分片或已分片FlatParameter的视图。使用当前算法时,分片形式总是1D,丢失了原始张量结构。对于给定的rank,一个原始参数可能具有全部、部分或没有其数据存在的情况。在没有任何数据的情况下,它的数据将类似于大小为0的空张量。用户不应编写依赖于给定原始参数在其分片形式中存在哪些数据的程序。True是使用torch.compile()所必需的。将其设置为False可以让FSDP内部的FlatParameter通过nn.Module.named_parameters()暴露给用户。(默认值:False)

  • ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – 忽略的参数或模块,这些参数和模块将不会由该 FSDP 实例管理。这意味着参数不会被分片,并且它们的梯度不会在不同 ranks 之间进行归约。此参数与现有的 ignored_modules 参数统一起来,我们可能会很快弃用 ignored_modules。为了向后兼容性,同时保留了 ignored_statesignored_modules` 两个参数,但 FSDP 只允许其中一个被指定为不为 None

  • device_mesh (可选[DeviceMesh]) – DeviceMesh 可以作为 process_group 的替代方案。当使用 device_mesh 时,FSDP 将利用底层进程组进行 all-gather 和 reduce-scatter 集体通信操作。因此,这两个参数需要互斥。对于混合分片策略(如 ShardingStrategy.HYBRID_SHARD),用户可以传递一个二维的 DeviceMesh 而不是进程组元组。对于 2D FSDP + TP 场景,必须使用 device_mesh 替代 process_group。更多关于 DeviceMesh 的信息,请访问:https://pytorch.org/tutorials/recipes/distributed_device_mesh.html

apply(fn)[源代码]

fn递归地应用到每个子模块(通过.children()获取)及其本身。

典型的使用场景包括初始化模型的参数(参见 torch.nn.init)。

torch.nn.Module.apply相比,这个版本会在应用fn之前先收集所有参数。它不应该在一个summon_full_params的上下文中被调用。

参数

fn (Module -> None) – 用于每个子模块的函数

返回值

自己

返回类型

模块

check_is_root()[源代码]

判断该实例是否是根FSDP模块。

返回类型

bool

clip_grad_norm_(max_norm, norm_type=2.0)[源代码]

限制所有参数的梯度范数。

范数是基于将所有参数的梯度视为一个整体向量来计算的,而梯度会在原地被修改。

参数
  • max_norm (floatint) – 梯度的最大值规范

  • norm_type (floatint) – 表示使用的 p-范数的类型。可以设置为 'inf' 来表示无穷范数。

返回值

将所有参数视为一个向量时的总范数。

返回类型

Tensor

如果每个 FSDP 实例都使用 NO_SHARD(即没有梯度在不同 ranks 之间分片),则你可以直接使用 torch.nn.utils.clip_grad_norm_()

如果至少有一些 FSDP 实例使用了分片策略(即不是 NO_SHARD),你应该使用此方法而不是 torch.nn.utils.clip_grad_norm_(),因为该方法能够处理梯度在不同 ranks 之间分片的情况。

返回的总范数将根据 PyTorch 的类型提升语义,采用所有参数和梯度中“最大”的数据类型。例如,如果所有参数和梯度都使用低精度的数据类型,则返回的范数也将是该低精度的数据类型;但如果至少有一个参数或梯度使用 FP32 类型,则返回的范数将为 FP32。

警告

因为它是基于集体通信的,所以需要在所有进程中进行调用。

staticflatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[源代码]

展平分片优化器的状态字典。

该 API 与 shard_full_optim_state_dict() 类似。唯一的区别是,输入的 sharded_optim_state_dict 应由 sharded_optim_state_dict() 提供。因此,在每个 rank 上都会进行 all-gather 调用以收集 ShardedTensor

参数
返回值

参见 shard_full_optim_state_dict()

返回类型

Dict[str, Any]

forward(*args, **kwargs)[源代码]

为包裹的模块运行前向传递,并插入特定于 FSDP 的前置和后置碎片化逻辑。

返回类型

Any

静态fsdp_modules(module, root_only=False)[源代码]

返回所有的嵌套FSDP实例。

这可能包括module本身,而且只有当root_only=True时才会包含FSDP根模块。

参数
  • module (torch.nn.Module) – 根模块,该模块可能是也可能是不是 FSDP 模块。

  • root_only (bool) – 是否仅返回 FSDP 根模块。默认值为 False

返回值

输入 module 中嵌套的 FSDP 模块。

返回类型

列表:FullyShardedDataParallel

staticfull_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[源代码]

返回优化器的完整状态字典。

在 rank 0 上合并完整的优化器状态,并按照 torch.optim.Optimizer.state_dict() 的约定将其作为 dict 返回,即包含键 "state""param_groups"。将 model 中的 FSDP 模块中的扁平化参数映射回未扁平化的形式。

由于使用了集体通信,需要在所有进程中调用此函数。然而,如果设置 rank0_only=True,则只有 rank 0 进程会填充状态字典,其他所有进程返回一个空的 dict

与此相反,torch.optim.Optimizer.state_dict() 使用完整参数名称作为键,而不是参数 ID。

类似于torch.optim.Optimizer.state_dict(),优化器状态字典中的张量不会被复制,因此可能会出现别名带来的意外情况。为了最佳实践,建议立即保存返回的优化器状态字典,例如使用torch.save()

参数
  • model (torch.nn.Module) – 根模块,该模块可能是也可能不是FullyShardedDataParallel 实例,其参数被传递给优化器 optim

  • optim (torch.optim.Optimizer) – model 参数的优化器。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器 optim 的输入,表示参数组的列表或可迭代的参数;如果为 None,则默认输入为 model.parameters()。此参数已弃用,无需再传入。(默认值: None

  • rank0_only (bool) – 如果为True,仅在 rank 0 上保存填充的 dict; 如果为False,则在所有 ranks 上保存。 (默认值: True)

  • group (dist.ProcessGroup) – 模型的过程组,如果不使用默认过程组,则为None。(默认值: None

返回值

包含 model 原始未展平参数的优化器状态的 dict,遵循 torch.optim.Optimizer.state_dict() 的约定,包括“state”和“param_groups”键。如果 rank0_only=True,则非零 rank 返回一个空的 dict

返回类型

字典[str, Any]

staticget_state_dict_type(module)[源代码]

获取以 module 为根的 FSDP 模块的状态字典类型及其对应的配置。

目标模块不一定是FSDP模块。

返回值

一个包含当前设置的 StateDictSettings,包括 state_dict 类型以及 state_dict 和 optim_state_dict 的配置。

异常
  • AssertionError 如果 StateDictSettings 设置不同 –

  • FSDP 的子模块各不相同。

返回类型

StateDictSettings

属性 模块:Module

返回封装的模块。

named_buffers(*args, **kwargs)[源代码]

返回一个迭代器来遍历模块的缓冲区,并同时 yield 缓冲区的名字和其内容。

拦截缓冲区名称,并在使用summon_full_params()上下文管理器时,移除所有特定于FSDP的扁平化缓冲区前缀。

返回类型

Iterator[Tuple[str, Tensor]]

named_parameters(*args, **kwargs)[源代码]

返回一个迭代器,用于遍历模块参数,并同时 yield 参数的名称和参数本身。

summon_full_params()上下文中,拦截参数名称并移除所有特定于FSDP的扁平化参数前缀。

返回类型

Iterator[Tuple[str, Parameter]]

no_sync()[源代码]

禁止FSDP实例间的梯度同步。

在此上下文中,梯度将累积在模块变量中,并在退出上下文后通过首次正向和反向传递进行同步。此功能仅适用于根 FSDP 实例,并会递归地应用到所有子 FSDP 实例。

注意

这可能导致更高的内存使用,因为FSDP会累积完整的模型梯度(而不是梯度片段),直到最终同步。

注意

当与CPU卸载一起使用时,梯度在上下文管理器内部不会被卸载到CPU。相反,它们会在最终同步之后立即被卸载。

返回类型

生成器

静态optim_state_dict(model, optim, optim_state_dict=None, group=None)[源代码]

将与分片模型对应的优化器的状态字典进行转换。

给定的状态字典可以转换成以下三种类型之一:1)完整优化器状态字典,2)分片优化器状态字典,3)本地优化器状态字典。

对于完整的优化器状态字典,所有状态都进行了反序列化且未分片。可以通过state_dict_type()指定仅 Rank0 或仅使用 CPU 来避免内存溢出。

对于分片优化器状态字典,所有状态会被解嵌套但保持分片。通过state_dict_type()指定CPU模式来进一步节省内存。

对于本地 state_dict,不会进行任何转换。但是,状态会从 nn.Tensor 转换为 ShardedTensor,以表示其分片的特性(当前不支持此功能)。

示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
参数
  • model (torch.nn.Module) – 根模块,该模块可能是也可能不是FullyShardedDataParallel 实例,其参数被传递给优化器 optim

  • optim (torch.optim.Optimizer) – model 参数的优化器。

  • optim_state_dict (Dict[str, Any]) – 目标优化器状态字典,用于转换。如果值为 None,则使用 optim.state_dict()。(默认值: None)

  • group (dist.ProcessGroup) – 模型的进程组,参数在其上进行分片处理;如果使用默认进程组,则为None。 ( 默认值: None)

返回值

包含model优化器状态的dict。优化器状态的分片依据state_dict_type

返回类型

字典[str, Any]

staticoptim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[源代码]

将优化器的状态字典转换为可以加载到与 FSDP 模型相关联的优化器中的格式。

给定一个通过optim_state_dict()转换的optim_state_dict,它会被转换为可以加载到optim中的扁平化优化器状态字典。其中,optimmodel的优化器,并且model必须通过FullyShardedDataParallel进行分片。

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
参数
  • model (torch.nn.Module) – 根模块,该模块可能是也可能不是FullyShardedDataParallel 实例,其参数被传递给优化器 optim

  • optim (torch.optim.Optimizer) – model 参数的优化器。

  • optim_state_dict (Dict[str, Any]) – 需要加载的优化器状态信息。

  • is_named_optimizer (bool) – 表示该优化器是否为 NamedOptimizer 或 KeyedOptimizer。仅在 optim 是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 时,此值才设置为 True。

  • load_directly (bool) – 如果设置为 True,则此 API 在返回结果前会调用 optim.load_state_dict(result)。否则,用户需要自行调用 optim.load_state_dict() (默认值:False

  • group (dist.ProcessGroup) – 模型的进程组,参数在其上进行分片处理;如果使用默认进程组,则为None。 ( 默认值: None)

返回类型

Dict[str, Any]

register_comm_hook(state, hook)[源代码]

注册一个通讯钩子。

这是一个增强功能,为用户提供了一个灵活的挂钩,允许他们在多个工作进程中指定如何聚合梯度。此挂钩可以用于实现多种算法,例如GossipGrad 和涉及不同通信策略的梯度压缩算法,在使用FullyShardedDataParallel 进行训练时。

警告

FSDP 通信钩子应在运行初始前向传递之前注册,且仅需注册一次。

参数
  • state (对象) –

    在训练过程中,钩子会接收一些信息来维护状态。例如,在梯度压缩中传递误差反馈,在GossipGrad中确定下一个通信的对等节点等等。这些信息由每个工作者本地存储,并在该工作者的所有梯度张量之间共享。

  • hook (Callable) – 可调用对象,具有以下签名之一:1) hook: Callable[torch.Tensor] -> None: 此函数接收一个 Python 张量,该张量表示相对于此 FSDP 单元包装的模型(不被其他 FSDP 子单元包装)的所有变量的完整、展平且未分片的梯度。然后执行所有必要的处理并返回 None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: 此函数接收两个 Python 张量,第一个张量表示相对于此 FSDP 单元包装的模型(不被其他 FSDP 子单元包装)的所有变量的完整、展平且未分片的梯度。第二个张量表示一个预分配大小的张量,在减少后存储分片梯度的一部分。在这两种情况下,可调用对象执行所有必要的处理并返回 None。具有签名 1 的可调用对象预计要处理NO_SHARD情况下的梯度通信。具有签名 2 的可调用对象预计要处理分片情况下的梯度通信。

静态rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[源代码]

将优化器状态字典 optim_state_dict 的键重新设置为 optim_state_key_type 类型。

这可以用来实现包含和不包含FSDP实例的模型之间优化器状态字典的兼容性。

要将 FSDP 完整优化器状态字典(即从 full_optim_state_dict())重新键入为参数 ID,并使其可以加载到未包装的模型中:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

要将普通优化器的状态字典从非包装模型转换为可加载到包装模型中的格式:

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
返回值

优化器状态字典根据 optim_state_key_type 指定的参数键进行了重新键化。

返回类型

字典[str, Any]

staticscatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[源代码]

将 rank 0 的完整优化器状态字典分散到所有其他 rank。

返回每个排名上的分片优化器状态字典。返回值与shard_full_optim_state_dict()相同,在排名0上,第一个参数应为full_optim_state_dict()的返回值。

示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意

可以使用shard_full_optim_state_dict() 或者 scatter_full_optim_state_dict() 来获取要加载的分片优化器状态字典。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个进程在 CPU 内存中有完整的字典,并且每个进程单独地对字典进行分片而无需通信;后者只需要 rank 0 在 CPU 内存中有完整的字典,并将每个分片移动到 GPU 内存(用于 NCCL),并将其适当地传递给其他进程。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。

参数
  • full_optim_state_dict (Optional[Dict[str, Any]]]) – 与未展平的参数对应的优化器状态字典,在 rank 0 上包含完整的非碎片化优化器状态;在非零 rank 上忽略此参数。

  • model (torch.nn.Module) – 根模块,该模块可能是也可能不是FullyShardedDataParallel 实例,其参数与full_optim_state_dict 中的优化器状态相对应。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组列表或可迭代的参数;如果为 None,则此方法假设输入为 model.parameters()。该参数已弃用,无需再传入。(默认值: None

  • optim (Optional[torch.optim.Optimizer]) – 优化器将加载此方法返回的状态字典。这是推荐使用的参数,优先于 optim_input。(默认值: None

  • group (dist.ProcessGroup) – 模型的过程组,如果不使用默认过程组,则为None。(默认值: None

返回值

完整的优化器状态字典现在被重新映射为扁平化参数,而非原来的非扁平化参数,并且只包括该排名下的优化器状态部分。

返回类型

字典[str, Any]

静态set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[源代码]

为目标模块的所有后代FSDP模块设置state_dict_type

还支持(可选的)模型和优化器状态字典的配置。目标模块不一定是FSDP模块。如果是FSDP模块,其state_dict_type也会被更改。

注意

应仅对此 API 的顶级(根)模块进行调用。

注意

此 API 允许用户透明地使用传统的 state_dict API,在根 FSDP 模块被另一个 nn.Module 包装的情况下保存模型检查点。例如,以下代码将确保对所有非 FSDP 实例调用 state_dict,而对于 FSDP 实例则调用sharded_state_dict实现:

示例:

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
参数
  • module (torch.nn.Module) – 根模块。

  • state_dict_type (StateDictType) – 指定要设置的 state_dict_type 类型。

  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的配置。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 用于优化器状态字典的配置。

返回值

包含模块之前的 state_dict 类型和配置的 StateDictSettings。

返回类型

StateDictSettings

staticshard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[源代码]
将一个完整的优化器状态字典进行分割。

full_optim_state_dict中的状态重映射为扁平化参数,而不是非扁平化参数,并且仅限于该 rank 的优化器状态部分。第一个参数应该是full_optim_state_dict()的返回值。

示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

注意

可以使用shard_full_optim_state_dict() 或者 scatter_full_optim_state_dict() 来获取要加载的分片优化器状态字典。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个进程在 CPU 内存中有完整的字典,并且每个进程单独地对字典进行分片而无需通信;后者只需要 rank 0 在 CPU 内存中有完整的字典,并将每个分片移动到 GPU 内存(用于 NCCL),并将其适当地传递给其他进程。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。

参数
  • full_optim_state_dict (Dict[str, Any]) – 与未展平参数相对应的优化器状态字典,包含完整的非分片优化器状态。

  • model (torch.nn.Module) – 根模块,该模块可能是也可能不是FullyShardedDataParallel 实例,其参数与full_optim_state_dict 中的优化器状态相对应。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组列表或可迭代的参数;如果为 None,则此方法假设输入为 model.parameters()。该参数已弃用,无需再传入。(默认值: None

  • optim (Optional[torch.optim.Optimizer]) – 优化器将加载此方法返回的状态字典。这是推荐使用的参数,优先于 optim_input。(默认值: None

返回值

完整的优化器状态字典现在被重新映射为扁平化参数,而非原来的非扁平化参数,并且只包括该排名下的优化器状态部分。

返回类型

字典[str, Any]

静态sharded_optim_state_dict(model, optim, group=None)[源代码]

以分片形式返回优化器的状态字典。

该 API 与 full_optim_state_dict() 类似,但此 API 将所有非零维度的状态分割为 ShardedTensor 以节省内存。此 API 应仅在模型的 state_dict 使用上下文管理器 with state_dict_type(SHARDED_STATE_DICT): 派生时使用。

关于详细的使用方法,请参阅full_optim_state_dict()

警告

返回的状态字典包含了ShardedTensor,因此不能直接被常规的optim.load_state_dict使用。

返回类型

Dict[str, Any]

静态state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[源代码]

为目标模块的所有后代FSDP模块设置state_dict_type

这个上下文管理器与set_state_dict_type()具有相同的功能。请参阅set_state_dict_type()的文档以获取详细信息。

示例:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
参数
  • module (torch.nn.Module) – 根模块。

  • state_dict_type (StateDictType) – 指定要设置的 state_dict_type 类型。

  • state_dict_config (Optional[StateDictConfig]) – 模型的 state_dict 配置,用于指定目标 state_dict_type

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 目标 state_dict_type 的优化器 state_dict 配置。

返回类型

生成器

静态方法summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[源代码]

使用此上下文管理器为FSDP实例暴露全部参数。

在模型完成前向和反向传播之后,可以使用它来获取参数以进行额外处理或检查。它可以接受一个非FSDP模块,并根据recurse 参数的不同,加载所有包含的FSDP模块及其子模块的完整参数。

注意

这可以应用于内部的FSDP。

注意

不能用于前向或后向传递中,也不能从该上下文中启动前向或后向传递。

注意

参数将在上下文管理器退出后恢复到本地分片,存储行为与前向传播时一致。

注意

可以修改所有参数,但在上下文管理器退出后,只有本地参数分片的部分会持久化(除非指定 writeback=False,此时更改将被丢弃)。在 FSDP 不对参数进行分片的情况下,目前仅当 world_size == 1 或使用 NO_SHARD 配置时,修改会持久化,无论 writeback 如何设置。

注意

此方法适用于不是 FSDP 模块本身的模块,但这些模块可能包含多个独立的 FSDP 单元。在这种情况下,给定的参数将应用于所有的 FSDP 单元。

警告

请注意,当前不支持同时使用 rank0_only=Truewriteback=True,这将引发错误。原因是在此上下文中,不同 rank 的模型参数形状会有所不同,而在退出上下文时写入这些参数会导致跨 rank 不一致。

警告

请注意,offload_to_cpurank0_only=False 会导致在同一台机器上的 GPU 将完整的参数冗余地复制到 CPU 内存中,从而可能引发 CPU 内存不足的问题。建议使用 offload_to_cpu 并设置 rank0_only=True

参数
  • recurse (bool, Optional) – 是否递归地为嵌套的FSDP实例获取所有参数(默认值:True)。

  • writeback (bool, Optional) – 如果为 False,在上下文管理器退出后对参数的修改将被丢弃;启用此功能可以稍微提高效率(默认值:True)

  • rank0_only (bool, Optional) – 如果为 True,完整的参数仅在全局 rank 0 上进行实例化。这意味着在此上下文中,只有 rank 0 将具有完整的参数,而其他 rank 则拥有分片的参数。需要注意的是,在同时设置 rank0_only=Truewriteback=True 的情况下是不支持的,因为在上下文内模型参数形状在不同 rank 上会有所不同,这可能导致退出上下文时不同 rank 之间的数据不一致。

  • offload_to_cpu (bool, Optional) – 如果为 True,则将完整参数卸载到 CPU。需要注意的是,当前仅在参数被分片时(除了 world_size = 1 或 NO_SHARD 配置外)才会发生卸载操作。建议使用 offload_to_cpu 并设置 rank0_only=True,以避免将模型参数的冗余副本卸载到同一 CPU 内存中。

  • with_grads (bool, Optional) – 如果为 True,梯度也会与参数一起解除分片。目前,这仅在将 use_orig_params=True 传递给 FSDP 构造函数且将 offload_to_cpu=False 传递给此方法时受支持。(默认值:False

返回类型

生成器

torch.distributed.fsdp.BackwardPrefetch(value)[源代码]

这配置了显式的反向预取,通过在反向传递中启用通信和计算的重叠来提高吞吐量,但会稍微增加内存使用。

  • BACKWARD_PRE: 这个选项提供了最大的重叠,但会增加最多的内存使用。它在计算当前参数的梯度之前预取下一组参数。这样可以使得下一个 all-gather 和当前的梯度计算同时进行,在峰值时,它会在内存中保存当前的参数集、下一组参数和当前的梯度集合。

  • BACKWARD_POST: 这种模式减少了重叠,但需要较少的内存使用。它在完成当前参数的梯度计算后预取下一组参数。这种模式使reduce-scatter与下一个梯度计算重叠,并且在为下一组参数分配内存之前释放当前参数集,在峰值时仅保留下一组参数和当前参数的梯度。

  • FSDP 的 backward_prefetch 参数可以设置为 None,这会完全禁用反向预取功能。这样做不会增加内存使用量或产生重叠。通常我们不推荐此设置,因为它可能会显著降低吞吐量。

从技术角度来讲:对于使用NCCL后端的单个进程组,即使从不同的流中发起,任何集合操作都会争夺同一设备上的NCCL流。这意味着集合操作的相对发起顺序会影响它们之间的重叠。两个反向预取值对应于不同的发起顺序。

torch.distributed.fsdp.ShardingStrategy(value)[源代码]

这定义了FullyShardedDataParallel在分布式训练中使用的分片策略。

  • FULL_SHARD: 参数、梯度和优化器状态都被分片处理。具体来说,参数在前向传播之前进行合并(使用 all-gather),在前向传播之后重新分片;反向计算之前再次合并,在反向计算之后重新分片。对于梯度,在反向计算后同步并分片(通过 reduce-scatter)。每个 rank 的优化器状态会独立更新。

  • SHARD_GRAD_OP: 在计算过程中,梯度和优化器状态会被分片,并且参数在计算外部也会被分片。对于参数,此策略会在前向传播之前合并,在前向传播之后不重新分片,仅在后向计算之后进行重新分片。每个进程会独立更新分片的优化器状态。在 no_sync() 内部,参数不会在后向计算之后被重新分片。

  • NO_SHARD: 参数、梯度和优化器状态不会被分片,而是像 PyTorch 的 DistributedDataParallel API 一样在各个 ranks 上进行复制。对于梯度,在反向计算后会通过 all-reduce 进行同步。未分片的优化器状态会在每个 rank 上进行本地更新。

  • HYBRID_SHARD: 在一个节点内应用 FULL_SHARD,并在节点之间复制参数。这样可以减少通信量,因为昂贵的 all-gathers 和 reduce-scatters 只在节点内部执行,对于中等规模的模型来说性能可能更好。

  • _HYBRID_SHARD_ZERO2: 在节点内应用SHARD_GRAD_OP,并在节点之间复制参数。这类似于HYBRID_SHARD,但可以提供更高的吞吐量,因为在前向传递之后未分片的参数不会被释放,从而节省了预反向中的all-gathers。

classtorch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[源代码]

这配置了FSDP原生的混合精度训练。

变量
  • param_dtype (Optional[torch.dtype]) – 这指定了在前向和反向传播过程中模型参数的数据类型,因此也决定了计算的前向和反向数据类型。在外向前向和反向传播时(例如优化器步骤),分片 参数保持全精度,并且对于模型检查点,参数始终以全精度保存。(默认值: None)

  • reduce_dtype (Optional[torch.dtype]) – 指定梯度减少(如 reduce-scatter 或 all-reduce)的数据类型。如果此参数为 Noneparam_dtype 不是 None,则使用 param_dtype 的值,并以低精度执行梯度减少操作。该数据类型可以与 param_dtype 不同,例如强制在全精度下进行梯度减少。(默认值: None)

  • buffer_dtype (Optional[torch.dtype]) – 指定缓冲区的数据类型。FSDP 不会对缓冲区进行分片处理,而是在第一次前向传递时将它们转换为 buffer_dtype 类型,并在此之后保持该数据类型不变。对于模型检查点保存,除了 LOCAL_STATE_DICT 之外,所有缓冲区将以全精度形式保存。(默认值: None)

  • keep_low_precision_grads (bool) – 如果为False,FSDP会在反向传播后将梯度上转换为全精度以准备优化器步骤。如果为True,则FSDP会保持用于梯度减少的数据类型中的梯度,这可以节省内存,特别是在使用支持低精度运行的自定义优化器时。(默认值:False

  • cast_forward_inputs (bool) – 如果为True,则此 FSDP 模块将其前向传递的参数和关键字参数转换为 param_dtype。这样可以确保参数与输入的数据类型匹配,以满足许多操作的要求。当仅对某些但不是所有 FSDP 模块应用混合精度时,可能需要设置为True,在这种情况下,一个使用混合精度的 FSDP 子模块需要重新转换其输入。(默认值: False)

  • cast_root_forward_inputs (bool) – 如果为 True,根 FSDP 模块会将其前向传递的参数和关键字参数转换为 param_dtype 类型,从而覆盖 cast_forward_inputs 的值。对于非根 FSDP 模块,此设置不会产生任何影响。(默认值:True

  • _module_classes_to_ignore (Sequence[Type[nn.Module]]) – 这指定了在使用 auto_wrap_policy 时要忽略的混合精度模块类。这些类的模块将分别应用 FSDP,并禁用混合精度(这意味着最终的 FSDP 构造会偏离指定策略)。如果未设置 auto_wrap_policy,则此设置无效。该 API 是实验性的,并可能发生变化。(默认值:(_BatchNorm,))

注意

此 API 仍在试验中,可能随时发生变化。

注意

只有浮点张量会转换为指定的数据类型。

注意

summon_full_params中,参数会被强制使用全精度,而缓冲区则不会。

注意

即使层归一化和批归一化的输入精度较低(如float16bfloat16),它们仍然会在float32中进行累积。禁用FSDP的混合精度仅意味着这些归一化模块中的仿射参数保持在float32中。然而,这会导致这些归一化模块执行单独的all-gathers和reduce-scatters操作,可能会降低效率。因此,如果工作负载允许的话,用户应更倾向于将混合精度应用于这些模块。

注意

默认情况下,如果用户传递一个包含任何 _BatchNorm 模块的模型,并且指定了 auto_wrap_policy,那么批处理归一化模块将分别应用 FSDP 并禁用混合精度。有关详细信息,请参见 _module_classes_to_ignore 参数。

注意

MixedPrecision 默认情况下设置为 cast_root_forward_inputs=Truecast_forward_inputs=False。对于根 FSDP 实例,cast_root_forward_inputs 优先于 cast_forward_inputs。非根 FSDP 实例的 cast_root_forward_inputs 值将被忽略。默认设置适用于通常情况下每个 FSDP 实例具有相同的 MixedPrecision 配置,并且只需要在模型前向传递开始时将输入转换为 param_dtype 的情况。

注意

对于具有不同MixedPrecision配置的嵌套FSDP实例,我们建议为每个实例设置单独的cast_forward_inputs值,以决定在前向传递之前是否进行输入转换。在这种情况下,由于每次转换都在每个FSDP实例的前向传递之前发生,因此父级FSDP实例应在运行其FSDP子模块之前先运行非FSDP子模块,以避免激活数据类型因不同的MixedPrecision配置而发生变化。

示例:

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

以上是一个工作的示例。另一方面,如果将 model[1] 替换为 model[0],意味着使用不同 MixedPrecision 的子模块先运行其前向计算,则 model[1] 会错误地看到 float16 激活值而不是 bfloat16 激活值。

torch.distributed.fsdp.CPUOffload(offload_params=False)[源代码]

这配置了CPU卸载功能。

变量

offload_params (bool) – 这个参数指定在不参与计算时是否将参数卸载到CPU。如果为 True,则还会将梯度卸载到CPU,这意味着优化器步骤将在CPU上运行。

torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[源代码]

StateDictConfig 是所有 state_dict 配置类的基础类。用户应实例化一个子类(例如 FullStateDictConfig)来为 FSDP 支持的相应 state_dict 类型配置设置。

变量

offload_to_cpu (bool) – 如果为 True,则 FSDP 将状态字典值卸载到 CPU;如果为 False,则 FSDP 保持它们在 GPU 上。(默认值: False)

classtorch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[源代码]

FullStateDictConfig 是一个配置类,用于与 StateDictType.FULL_STATE_DICT 一起使用。我们建议在保存完整状态字典时启用 offload_to_cpu=True 来节省 GPU 内存,并启用 rank0_only=True 来节省 CPU 内存。此配置类应通过以下方式使用 state_dict_type() 上下文管理器:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>>     # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
变量

rank0_only (bool) – 如果为True,则只有 rank 0 保存完整的状态字典,其他非零秩的节点保存空字典。如果为 False,则所有节点都保存完整的状态字典。(默认值:False

classtorch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[源代码]

ShardedStateDictConfig 是一个配置类,用于配合 StateDictType.SHARDED_STATE_DICT 使用。

变量

_use_dtensor (bool) – 如果为 True,则 FSDP 会将状态字典值保存为 DTensor;如果为 False,则会将其保存为 ShardedTensor。(默认: False

警告

_use_dtensorShardedStateDictConfig 的一个私有字段,用于确定状态字典值的类型。用户不应手动修改 _use_dtensor

torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)[源代码]
classtorch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[源代码]

OptimStateDictConfig 是所有 optim_state_dict 配置类的基础类。用户应实例化一个子类(例如 FullOptimStateDictConfig)来配置 FSDP 支持的相应 optim_state_dict 类型的设置。

变量

offload_to_cpu (bool) – 如果设置为 True,FSDP 会将状态字典中的张量值移到 CPU 上;如果设置为 False,则 FSDP 保持这些张量在原始设备上(默认情况下是 GPU,除非启用了参数级的 CPU 卸载)。(默认值:True)

classtorch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[源代码]
变量

rank0_only (bool) – 如果为True,则只有 rank 0 保存完整的状态字典,其他非零秩的节点保存空字典。如果为 False,则所有节点都保存完整的状态字典。(默认值:False

classtorch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[源代码]

ShardedOptimStateDictConfig 是一个配置类,用于与 StateDictType.SHARDED_STATE_DICT 配合使用。

变量

_use_dtensor (bool) – 如果为 True,则 FSDP 会将状态字典值保存为 DTensor;如果为 False,则会将其保存为 ShardedTensor。(默认: False

警告

_use_dtensorShardedOptimStateDictConfig 的一个私有字段,用于确定状态字典值的类型。用户不应手动修改 _use_dtensor

torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)[源代码]
```html classtorch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source] ```
本页目录