分布式数据并行
警告
torch.nn.parallel.DistributedDataParallel
的实现会随着时间而演变。此设计说明是基于 v1.4 版本的状态编写的。
torch.nn.parallel.DistributedDataParallel
(DDP) 透明地执行分布式数据并行训练。本页详细介绍了其工作原理,并揭示了实现细节。
示例
我们从一个简单的 torch.nn.parallel.DistributedDataParallel
示例开始。此示例使用一个 torch.nn.Linear
作为本地模型,将其封装在 DDP 中,然后在 DDP 模型上执行一次前向传播、一次反向传播和一次优化器步骤。之后,本地模型的参数将被更新,不同进程中的所有模型应完全一致。
DDP 与 TorchDynamo 兼容。当与 TorchDynamo 一起使用时,在编译模型之前应用 DDP 模型包装器,以便 torchdynamo 可以根据 DDP 的分块大小应用 DDPOptimizer
(图断点优化)。有关更多信息,请参阅 TorchDynamo DDPOptimizer。
ddp_model = DDP(model, device_ids=[rank])
ddp_model = torch.compile(ddp_model)
内部机制
本节通过深入探讨每个迭代步骤的具体细节,揭示了 torch.nn.parallel.DistributedDataParallel
的内部工作原理。
- 前提条件: DDP 依赖于 c10d
ProcessGroup
进行通信。因此,应用程序必须在构建 DDP 之前创建ProcessGroup
实例。
*
-
前向传递: DDP 接收输入并将其传递给本地模型,然后在
find_unused_parameters
设置为True
时分析本地模型的输出。这种模式允许在模型的子图上进行计算,DDP 通过遍历自动梯度图来确定哪些参数参与了反向传播,并标记所有未使用的参数为准备好进行归约。在反向传播过程中,Reducer
只会等待未准备好的参数,但仍然会归约所有桶。目前,将参数梯度标记为准备好并不能帮助 DDP 跳过桶,但它可以防止 DDP 在反向传播过程中无限期地等待不存在的梯度。需要注意的是,遍历自动梯度图会引入额外的开销,因此应用程序只有在必要时才应将find_unused_parameters
设置为True
。 -
反向传递:
backward()
函数直接在损失Tensor
上调用,而 DDP 无法直接控制这一过程。DDP 使用在构造时注册的梯度挂钩来触发梯度同步。当某个梯度准备好时,其对应的 DDP 挂钩会触发,DDP 将该参数梯度标记为准备好进行归约。当一个桶中的所有梯度都准备好时,Reducer
会在该桶上启动异步allreduce
来计算所有进程之间的梯度均值。当所有桶都准备好时,Reducer
会阻塞等待所有allreduce
操作完成。完成后,平均梯度会被写入所有参数的param.grad
字段。因此,在反向传递之后,不同 DDP 进程中相同对应参数的梯度字段应该是相同的。 -
优化器步骤:从优化器的角度,它在优化一个本地模型。所有 DDP 进程中的模型副本可以保持同步,因为它们都从相同的状态出发,并且在每次迭代中都有相同的平均梯度值。
注意
DDP 要求所有进程上的 Reducer
实例必须以完全相同的顺序调用 allreduce
,这是通过始终按照桶索引顺序(而不是实际桶就绪顺序)来运行 allreduce
来实现的。如果跨进程的 allreduce
顺序不匹配,可能会导致错误的结果或 DDP 反向传播挂起。
实现
以下是 DDP 实现组件的指针。堆叠图显示了代码的结构。
ProcessGroup
-
ProcessGroup.hpp: 包含所有进程组实现的抽象 API。
c10d
库默认提供了三种实现:ProcessGroupGloo、ProcessGroupNCCL 和 ProcessGroupMPI。DistributedDataParallel
使用ProcessGroup::broadcast()
在初始化期间从 rank 0 的进程向其他进程发送模型状态,并使用ProcessGroup::allreduce()
求和梯度。 -
Store.hpp: 协助进程组实例的会面服务,使它们能够相互发现。
DistributedDataParallel
-
distributed.py: 是 DDP 的 Python 入口点。它实现了初始化步骤和
nn.parallel.DistributedDataParallel
模块的forward
函数,这些函数会调用 C++ 库。它的_sync_param
函数在单个 DDP 进程处理多个设备时执行进程内部参数同步,并从 rank 为 0 的进程广播模型缓冲区到所有其他进程。进程之间的参数同步在Reducer.cpp
中实现。 -
comm.h: 实现了合并广播辅助函数,该函数在初始化期间广播模型状态并在前向传播之前同步模型缓冲区时被调用。
-
reducer.h: 提供了反向传播中梯度同步的核心实现。它有三个入口点函数:
-
Reducer
: 构造函数在distributed.py
中被调用,注册Reducer::autograd_hook()
到梯度累加器(gradient accumulators)。 -
autograd_hook()
函数会在梯度准备好时由自动梯度引擎(autograd engine)调用。 -
prepare_for_backward()
在distributed.py
中 DDP 前向传播结束时被调用。当在 DDP 构造函数中将find_unused_parameters
设置为True
时,它会遍历自动梯度图(autograd graph)以查找未使用的参数。
-
TorchDynamo DDPOptimizer
DDP 的性能优势来自于在反向传播期间将 allreduce 集体操作与计算重叠。当使用 TorchDynamo 编译整个前向和后向图时,AotAutograd 会阻止这种重叠,因为 allreduce 操作是在整个优化后的反向计算完成后由 autograd 钩子触发的。
TorchDynamo 的 DDPOptimizer 通过在反向传播期间在 DDP 的 allreduce 桶的逻辑边界(即每个 allreduce 操作的起点)处拆分前向图来帮助解决这个问题。注意:目标是在反向传播期间拆分图,最简单的实现是在前向图中拆分,然后对每个部分调用 AotAutograd 和编译。这使得 DDP 的 allreduce 钩子可以在反向传播的不同部分之间触发,并安排通信与计算重叠。
请参阅 这篇博客文章,以获取更深入的解释和实验结果,或阅读 torch/_dynamo/optimizations/distributed.py 中的文档和代码。
要调试 DDPOptimizer,设置 TORCH_LOGS='ddp_graphs'
以获取完整的图转储。对于不包含图的日志,可以在 TORCH_LOGS
中添加 'dynamo'、'distributed' 或 'dist_ddp'(以获取关于桶边界的基信息)。要禁用 DDPOptimizer,请设置 torch._dynamo.config.optimize_ddp=False
。即使没有 DDPOptimizer,DDP 和 TorchDynamo 仍然可以正常工作,但性能会有所下降。