FSDP 注意事项

FSDP 预取细微差别

为了在 forward 计算时重叠 forward 的全规约操作,有两种可能的机制:

  1. 隐式前向预取(始终启用)

  2. 显式前向预取(forward_prefetch=True

隐式 forward 预取是指通过从单独的 CUDA 流中发出全规约操作,使得全规约操作可以与之前从 CPU 发出的 forward 计算重叠。例如,如果我们有层 0 全规约 -> 层 0 forward 计算 -> 层 1 全规约 -> …,那么层 1 全规约可以与层 0 forward 计算重叠,即使 CPU 线程是在之后发出的。(第一个全规约将无法与任何操作重叠。)

显式 forward 预取是指改变 CPU 线程的发出顺序,例如:层 0 全规约 -> 层 1 全规约 -> 层 0 forward 计算 -> …。在急切模式下,通常无法确定当前执行的层(例如层 0)的下一个层(例如层 1)。因此,显式 forward 预取仅适用于每次迭代执行顺序固定的模型(我们有时称之为“静态图”)。一个不满足此约束的模型示例是 FLAVA

显式 forward 预取仅节省了发出一层的 forward 计算内核所需的时间,但缺点是在当前输出张量仍在使用时必须分配下一个 all-gather 的输出张量。通过在当前 forward 计算内核之前发出下一个 all-gather,可以在 GPU 上更早启动下一个 all-gather。对于大多数大语言模型(LLM)工作负载,这种情况不常见,因此没有启用 forward_prefetch=True 的动机。

相比之下,对于 backward,我们必须使用显式的 backward 预取,否则通信和计算将完全不重叠。原因是我们在 all-gather 和 reduce-scatter 操作中使用了一个 NCCL 进程组(部分原因是早期的 NCCL 版本在同一设备上并行使用多个进程组是不安全的)。单个 NCCL 进程组意味着一个内部 NCCL 流,在该流上 all-gather 和 reduce-scatter 顺序执行。因此,除非我们显式地重新排序 CPU 发出的顺序为下一个 all-gather -> 当前的 reduce-scatter,否则当前的 reduce-scatter 将会阻塞下一个 all-gather,从而阻止下一个 backward 计算。这会导致当前的 reduce-scatter 无法与计算重叠。

通信负载大小

在 FSDP 中,通信包括:

  1. forward 过程中对参数进行 all-gather 操作
  2. backward 过程中对参数进行 all-gather 操作
  3. backward 过程中对梯度进行 reduce-scatter 操作

如果使用激活检查点(checkpoint()),则不会产生额外的通信,因为在 backward 过程中参数无论如何都会被预取。

在 FSDP 设计中,每个 rank 的通信负载确定如下:每次调用 FullyShardedDataParallel 时,都会创建一个通信组。该组包含 module.parameters() 中的参数,但不包括已分配给嵌套 FullyShardedDataParallel 实例的参数。例如,对于 Llama,如果你将 FullyShardedDataParallel 应用于每个变压器块以及根模块,那么每个变压器块将有一个通信组,最后还会有一个包含初始嵌入和最终线性的通信组。每个通信组对应一次 all-gather 操作和一次 reduce-scatter 操作。因此,FullyShardedDataParallel 的应用方式决定了通信的大小。一般来说,对于 LLMs,将 FSDP 应用于每个变压器块是一个很好的启发式方法。鉴于当前的设计,很难有更优的方案。

让我们考虑一个例子,假设我们有一个基于 Transformer 的模型,该模型分布在 8 个 GPU 上,切分仅在 Transformer 块级别进行,每个 Transformer 块包含 1.6B 参数,参数为 fp32(每个 4 字节)。这意味着切分后,每个 Transformer 块在每个节点上将包含 0.2B 参数。

  • 在前向传播过程中,所有聚合通信将以 0.2*4 = 0.8GB 的块进行。

  • 在反向传播过程中,每次通信将进行两次,每次 0.8GB(一次所有聚合和一次减少分散)。

换句话说,将有 3 次通信,每次通信的负载为 0.8GB。如果模型由 10 个 Transformer 块组成,则总共会有 30 次通信,总通信量为 30*0.8=24GB

为了形式化每节点每次通信的负载大小,公式为 total_transformer_block_params_in_B * dtype_bytes / num_gpus(GB)。

请注意,在这个例子中,我们没有包括嵌入层所需的额外通信,这些通信也应计入总通信量。具体来说,计算取决于输入和输出嵌入是否绑定。如果它们不绑定,则通信次数会增加一倍。

FSDP 缓冲区大小

首先,我们来讨论为通信分配的缓冲区:

forward 当前需要 2 倍的全聚缓冲区大小。原因如下:

FSDP 预取细微差别 中所述,在显式 forward 预取(forward_prefetch=True)的情况下,第 0 层全聚 -> 第 0 层前向计算 -> 第 1 层全聚 这个序列需要 2 个全聚大小的缓冲区,因为一个缓冲区在当前 forward 中使用,而另一个缓冲区用于预取。

尽管在隐式 forward 预取(forward_prefetch=False,默认值)的情况下,理论上只需要 1 个缓冲区,但实际上仍然需要 2 倍的全聚大小的缓冲区。原因是,在扁平参数的 FSDP 设计中,我们不会从全聚缓冲区复制数据。用于计算的参数直接从全聚缓冲区中读取(这是‘扁平参数’设计的主要优势之一)。在这种情况下,当“第 1 层全聚”与“第 0 层前向计算”重叠时,“第 0 层前向计算”正在使用从“第 0 层全聚”缓冲区中读取的参数。

一个自然的问题是,什么时候你会希望 `forward_prefetch=False`?对于静态图模型(如大多数大语言模型),有一个主要的技术原因:在这些模型中,前向预取可能会导致不必要的内存开销和复杂性。实际上,我们为了某些依赖 CPU 的内部模型快速添加了这个选项,尚未在单元测试中全面验证,因此对其稳定性信心不足。`forward_prefetching=False` 可以稍微更容易理解,因为在这种情况下,我们不必检查记录的前向顺序作为可能的“失败模式”。此外,模块的所有聚合操作总是可以在其自己的 `record_function` 标签下的性能分析跟踪中找到。

`backward` 当前至少需要 2 倍的 all-gather 缓冲区大小,可能还需要更多。原因如下:

当前的 FSDP 设计使用 `recordStream` 来管理在一个流中产生并在另一个流中消费的分配,这可能导致比预期更多的内存使用。具体多出多少内存使用是不确定的,因为它取决于 GPU 内核与 CPU 之间的时序关系。`limit_all_gathers=True` 参数可以缓解这一问题。更多详细信息请参阅此讨论:[FSDP & CUDACachingAllocator](https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486/1)。

现有的 FSDP 与 autograd 的工作方式:

  • 现有的 FSDP 会将每个 flat_param(即 autograd 叶子节点)分别聚集。

  • 它调用 torch.split 来获取 flat_param 中对应于其组成原始参数的 1D 视图。

  • 它对每个 1D 分割调用 torch.view 以重新视图为 ND。

  • 这意味着在 backward 过程中,我们会遇到 ViewBackward(从 ND 到 1D)和 SplitWithSizesBackward(这是一个拼接操作)。具体来说,每个单独的梯度会被单独分配并计算,然后显式地进行拼接以构建 reduce-scatter 输入缓冲区。这实际上意味着在峰值内存点时,reduce-scatter 的缓冲区大小为 2 倍。

总结来说,对于 backward,总的缓冲区大小大约是 reduce-scatter 缓冲区大小的 2 倍,再加上任何 recordStream 效果。

其次,让我们讨论额外的缓冲区:

一旦从所有 rank 聚集了分片参数,它们需要一个额外的缓冲区,总大小为 total_transformer_block_params_in_B * dtype_bytes,用于存储完整的参数。继续之前的例子,如果每个 transformer 块有 1.6B 参数且参数为 fp32,则缓冲区大小为 1.6 * 4 = 6.4GB。

并且需要 2 个这样的缓冲区,因为当前有一个正在使用,另一个正在预取。

总结一下,我们有:

  1. 需要 2 个通信缓冲区 total_transformer_block_params_in_B*dtype_bytes/num_gpus

  2. 需要 2 个未分片的变压器块参数缓冲区 total_transformer_block_params_in_B*dtype_bytes

或者按照前面示例中的计算:

  1. 2*1.6*4/8=1.6GB

  2. 2*1.6*4=12.8GB

总计为 14.4GB

现在让我们简要讨论一下嵌入层的情况,因为之前在计算中没有包括它们:

根据我们讨论过的规则(即“通信缓冲区大小的确定方法”),我们可以进行如下分析:

  • 假设我们将 FSDP 应用于根模块(例如 Transformer 类)。我们进一步将 FSDP 应用于每个 Transformer 块(例如 TransformerBlock 类)。

  • 最常见的情况是,embedding 层和最终的 linear 层是根 Transformer 类的直接子模块。

  • 根据我们的规则,这意味着 embedding 层和最终的 linear 层被分配给根 Transformer 的扁平参数。

  • 我们还有一个特殊规则,即根模块在前向传播后不会释放其参数,因为这些参数无论如何都会在反向传播时立即全部收集。

  • 综上所述,这意味着根模块的扁平参数(包括 embedding 层和最终投影)在前向传播开始时会被全部收集,并一直保留在 GPU 内存中直到反向传播结束。

  • 如果 embedding 层和最终 linear 层没有权重共享,我们可以进一步将 FSDP 应用于这两个层。对于权重共享的参数,我们需要它们属于同一个扁平参数,以避免重复计算。这将允许在前向传播使用后释放 embedding 层,并仅在反向传播接近结束时进行全部收集。

希望这能更好地说明——每个 FSDP 模块会将其参数分配到 module.parameters 中,但不包括已经分配给其他嵌套 FSDP 模块的参数,而 FSDP 模块的 forward 方法定义了其参数的‘活跃’时间段。因此,嵌套的 nn.Module 结构会影响所有gather和释放的调度,从而影响内存和吞吐量性能。

本页目录