FSDP 注意事项
FSDP 预取细微差别
为了在 forward
计算时重叠 forward
的全规约操作,有两种可能的机制:
-
隐式前向预取(始终启用)
-
显式前向预取(
forward_prefetch=True
)
隐式 forward
预取是指通过从单独的 CUDA 流中发出全规约操作,使得全规约操作可以与之前从 CPU 发出的 forward
计算重叠。例如,如果我们有层 0 全规约 -> 层 0 forward
计算 -> 层 1 全规约 -> …,那么层 1 全规约可以与层 0 forward
计算重叠,即使 CPU 线程是在之后发出的。(第一个全规约将无法与任何操作重叠。)
显式 forward
预取是指改变 CPU 线程的发出顺序,例如:层 0 全规约 -> 层 1 全规约 -> 层 0 forward
计算 -> …。在急切模式下,通常无法确定当前执行的层(例如层 0)的下一个层(例如层 1)。因此,显式 forward
预取仅适用于每次迭代执行顺序固定的模型(我们有时称之为“静态图”)。一个不满足此约束的模型示例是 FLAVA。
显式 forward
预取仅节省了发出一层的 forward
计算内核所需的时间,但缺点是在当前输出张量仍在使用时必须分配下一个 all-gather 的输出张量。通过在当前 forward
计算内核之前发出下一个 all-gather,可以在 GPU 上更早启动下一个 all-gather。对于大多数大语言模型(LLM)工作负载,这种情况不常见,因此没有启用 forward_prefetch=True
的动机。
相比之下,对于 backward
,我们必须使用显式的 backward
预取,否则通信和计算将完全不重叠。原因是我们在 all-gather 和 reduce-scatter 操作中使用了一个 NCCL 进程组(部分原因是早期的 NCCL 版本在同一设备上并行使用多个进程组是不安全的)。单个 NCCL 进程组意味着一个内部 NCCL 流,在该流上 all-gather 和 reduce-scatter 顺序执行。因此,除非我们显式地重新排序 CPU 发出的顺序为下一个 all-gather -> 当前的 reduce-scatter,否则当前的 reduce-scatter 将会阻塞下一个 all-gather,从而阻止下一个 backward
计算。这会导致当前的 reduce-scatter 无法与计算重叠。
通信负载大小
在 FSDP 中,通信包括:
- 在
forward
过程中对参数进行 all-gather 操作 - 在
backward
过程中对参数进行 all-gather 操作 - 在
backward
过程中对梯度进行 reduce-scatter 操作
如果使用激活检查点(checkpoint()
),则不会产生额外的通信,因为在 backward
过程中参数无论如何都会被预取。
在 FSDP 设计中,每个 rank 的通信负载确定如下:每次调用 FullyShardedDataParallel
时,都会创建一个通信组。该组包含 module.parameters()
中的参数,但不包括已分配给嵌套 FullyShardedDataParallel
实例的参数。例如,对于 Llama,如果你将 FullyShardedDataParallel
应用于每个变压器块以及根模块,那么每个变压器块将有一个通信组,最后还会有一个包含初始嵌入和最终线性的通信组。每个通信组对应一次 all-gather 操作和一次 reduce-scatter 操作。因此,FullyShardedDataParallel
的应用方式决定了通信的大小。一般来说,对于 LLMs,将 FSDP 应用于每个变压器块是一个很好的启发式方法。鉴于当前的设计,很难有更优的方案。
让我们考虑一个例子,假设我们有一个基于 Transformer 的模型,该模型分布在 8 个 GPU 上,切分仅在 Transformer 块级别进行,每个 Transformer 块包含 1.6B 参数,参数为 fp32(每个 4 字节)。这意味着切分后,每个 Transformer 块在每个节点上将包含 0.2B 参数。
-
在前向传播过程中,所有聚合通信将以
0.2*4 = 0.8GB
的块进行。 -
在反向传播过程中,每次通信将进行两次,每次
0.8GB
(一次所有聚合和一次减少分散)。
换句话说,将有 3 次通信,每次通信的负载为 0.8GB
。如果模型由 10 个 Transformer 块组成,则总共会有 30 次通信,总通信量为 30*0.8=24GB
。
为了形式化每节点每次通信的负载大小,公式为 total_transformer_block_params_in_B * dtype_bytes / num_gpus
(GB)。
请注意,在这个例子中,我们没有包括嵌入层所需的额外通信,这些通信也应计入总通信量。具体来说,计算取决于输入和输出嵌入是否绑定。如果它们不绑定,则通信次数会增加一倍。
FSDP 缓冲区大小
首先,我们来讨论为通信分配的缓冲区:
forward
当前需要 2 倍的全聚缓冲区大小。原因如下:
如 FSDP 预取细微差别 中所述,在显式 forward
预取(forward_prefetch=True
)的情况下,第 0 层全聚 -> 第 0 层前向计算 -> 第 1 层全聚 这个序列需要 2 个全聚大小的缓冲区,因为一个缓冲区在当前 forward
中使用,而另一个缓冲区用于预取。
尽管在隐式 forward
预取(forward_prefetch=False
,默认值)的情况下,理论上只需要 1 个缓冲区,但实际上仍然需要 2 倍的全聚大小的缓冲区。原因是,在扁平参数的 FSDP 设计中,我们不会从全聚缓冲区复制数据。用于计算的参数直接从全聚缓冲区中读取(这是‘扁平参数’设计的主要优势之一)。在这种情况下,当“第 1 层全聚”与“第 0 层前向计算”重叠时,“第 0 层前向计算”正在使用从“第 0 层全聚”缓冲区中读取的参数。
一个自然的问题是,什么时候你会希望 `forward_prefetch=False`?对于静态图模型(如大多数大语言模型),有一个主要的技术原因:在这些模型中,前向预取可能会导致不必要的内存开销和复杂性。实际上,我们为了某些依赖 CPU 的内部模型快速添加了这个选项,尚未在单元测试中全面验证,因此对其稳定性信心不足。`forward_prefetching=False` 可以稍微更容易理解,因为在这种情况下,我们不必检查记录的前向顺序作为可能的“失败模式”。此外,模块的所有聚合操作总是可以在其自己的 `record_function` 标签下的性能分析跟踪中找到。
`backward` 当前至少需要 2 倍的 all-gather 缓冲区大小,可能还需要更多。原因如下:
当前的 FSDP 设计使用 `recordStream` 来管理在一个流中产生并在另一个流中消费的分配,这可能导致比预期更多的内存使用。具体多出多少内存使用是不确定的,因为它取决于 GPU 内核与 CPU 之间的时序关系。`limit_all_gathers=True` 参数可以缓解这一问题。更多详细信息请参阅此讨论:[FSDP & CUDACachingAllocator](https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486/1)。
现有的 FSDP 与 autograd 的工作方式:
-
现有的 FSDP 会将每个
flat_param
(即 autograd 叶子节点)分别聚集。 -
它调用
torch.split
来获取flat_param
中对应于其组成原始参数的 1D 视图。 -
它对每个 1D 分割调用
torch.view
以重新视图为 ND。 -
这意味着在
backward
过程中,我们会遇到ViewBackward
(从 ND 到 1D)和SplitWithSizesBackward
(这是一个拼接操作)。具体来说,每个单独的梯度会被单独分配并计算,然后显式地进行拼接以构建 reduce-scatter 输入缓冲区。这实际上意味着在峰值内存点时,reduce-scatter 的缓冲区大小为 2 倍。
总结来说,对于 backward
,总的缓冲区大小大约是 reduce-scatter 缓冲区大小的 2 倍,再加上任何 recordStream
效果。
其次,让我们讨论额外的缓冲区:
一旦从所有 rank 聚集了分片参数,它们需要一个额外的缓冲区,总大小为 total_transformer_block_params_in_B * dtype_bytes
,用于存储完整的参数。继续之前的例子,如果每个 transformer 块有 1.6B 参数且参数为 fp32,则缓冲区大小为 1.6 * 4 = 6.4GB。
并且需要 2 个这样的缓冲区,因为当前有一个正在使用,另一个正在预取。
总结一下,我们有:
-
需要 2 个通信缓冲区
total_transformer_block_params_in_B*dtype_bytes/num_gpus
-
需要 2 个未分片的变压器块参数缓冲区
total_transformer_block_params_in_B*dtype_bytes
或者按照前面示例中的计算:
-
2*1.6*4/8=1.6GB
-
2*1.6*4=12.8GB
总计为 14.4GB
。
现在让我们简要讨论一下嵌入层的情况,因为之前在计算中没有包括它们:
根据我们讨论过的规则(即“通信缓冲区大小的确定方法”),我们可以进行如下分析:
-
假设我们将 FSDP 应用于根模块(例如
Transformer
类)。我们进一步将 FSDP 应用于每个 Transformer 块(例如TransformerBlock
类)。 -
最常见的情况是,embedding 层和最终的 linear 层是根
Transformer
类的直接子模块。 -
根据我们的规则,这意味着 embedding 层和最终的 linear 层被分配给根
Transformer
的扁平参数。 -
我们还有一个特殊规则,即根模块在前向传播后不会释放其参数,因为这些参数无论如何都会在反向传播时立即全部收集。
-
综上所述,这意味着根模块的扁平参数(包括 embedding 层和最终投影)在前向传播开始时会被全部收集,并一直保留在 GPU 内存中直到反向传播结束。
-
如果 embedding 层和最终 linear 层没有权重共享,我们可以进一步将 FSDP 应用于这两个层。对于权重共享的参数,我们需要它们属于同一个扁平参数,以避免重复计算。这将允许在前向传播使用后释放 embedding 层,并仅在反向传播接近结束时进行全部收集。
希望这能更好地说明——每个 FSDP 模块会将其参数分配到 module.parameters
中,但不包括已经分配给其他嵌套 FSDP 模块的参数,而 FSDP 模块的 forward
方法定义了其参数的‘活跃’时间段。因此,嵌套的 nn.Module
结构会影响所有gather和释放的调度,从而影响内存和吞吐量性能。