分布式自动求导设计

本文将介绍分布式自动微分的详细设计,并探讨其实现细节。在继续之前,请确保你已经熟悉了自动微分机制分布式RPC框架

背景

假设你有两个节点和一个非常简单的模型,这个模型分布在两个节点上。你可以使用torch.distributed.rpc 进行如下实现:

import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()

分布式自动微分的主要目的是在分布式模型上运行反向传播,基于我们计算出的损失,并为所有需要梯度的张量记录相应的梯度。

前向传播过程中的自动求导记录

PyTorch 在前向传播过程中构建自动求导图,并使用此图执行反向传播。更多细节请参见 自动微分如何编码历史记录

对于分布式自动微分,我们需要在前向传递期间跟踪所有RPC,以确保反向传递能够正确执行。为此,在执行RPC时,我们将sendrecv函数添加到自动微分图中。

  • 代码中的 send 函数被附加到 RPC 的源端,其输出连接到输入张量的自动微分函数。在反向传播过程中,此函数的输入是由目的地通过相应的 recv 函数发送过来的。

  • 函数 recv 附加在 RPC 的目标端,并从该端执行的操作中获取输入张量。在反向传递过程中,此函数的输出梯度会被发送回源节点并交给相应的 send 函数。

  • 每对 send-recv 都会被分配一个全局唯一的 autograd_message_id,用于唯一标识该对。这在反向传播过程中查找远程节点上的相应函数时非常有用。

  • 对于RRef,每当调用torch.distributed.rpc.RRef.to_here()时,我们会为涉及的张量附加适当的send-recv对。

例如,这是上述示例的自动微分图(为了简洁,未包含 t5.sum())的样子:

{BASE_RAW_UPLOAD_URL}/pytorch-doc-2.5/14f8eb84aa4d320d59e986489fe863cf.png

分布式自动求导上下文

每个使用分布式自动微分的前向和后向传递都会被分配一个唯一的torch.distributed.autograd.context,并且此上下文具有全局唯一的autograd_context_id。根据需要在每个节点上创建该上下文。

该上下文的目的如下:

  1. 运行分布式反向传播的多个节点可能会在同一张 tensor 上累积梯度,因此在有机会运行优化器之前,该 tensor 的 .grad 字段会包含来自不同分布式反向传递的梯度。这类似于多次本地调用torch.autograd.backward()。为了为每个反向传播分离梯度,会在每次反向传播的 torch.distributed.autograd.context 中累积梯度。

  2. 在前向传递过程中,我们为每次自动微分传递在此上下文中存储sendrecv函数。这确保了对自动微分图中适当节点的引用得以保留,使其保持活动状态。此外,在反向传递期间查找适当的sendrecv函数也非常简单。

  3. 通常,我们还使用此上下文为每个分布式自动微分传递存储一些元数据。





从用户的角度来看,autograd 的上下文设置如下:

import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
  loss = model.forward()
  dist_autograd.backward(context_id, loss)

需要注意的是,您的模型的前向传递必须在分布式自动微分上下文中进行调用。这是因为需要一个有效的上下文来确保所有的sendrecv函数被正确存储,以便在所有参与节点上顺利执行反向传递。

分布式 backward pass

在本节中,我们将概述在分布式反向传递过程中准确计算依赖关系的挑战,并介绍几种具有权衡的算法,以说明如何执行分布式反向传递。

计算依赖关系

考虑在单台机器上运行以下代码

import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()

这便是上述代码对应的自动微分图(autograd图)的样子:

{BASE_RAW_UPLOAD_URL}/pytorch-doc-2.5/cff1ca30881a1ae6d72499f5122cdd9b.png

自动微分引擎在反向传递过程中的第一步是计算自动微分图中每个节点的依赖项数量。这有助于自动微分引擎确定何时可以执行某个节点。add(1)mul(0) 中括号内的数字表示依赖项的数量。如你所见,在反向传递过程中,add 节点需要 1 个输入,而 mul 节点不需要任何输入(换句话说,无需执行)。局部自动微分引擎通过从根节点(在这种情况下为 d)开始遍历图来计算这些依赖项。

某些节点在反向传递过程中可能不会被执行,这给分布式自动微分带来了挑战。请看下面这段使用RPC的代码。

import torch
import torch.distributed.rpc as rpc

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)

d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()

上述代码对应的自动微分图会是:

{BASE_RAW_UPLOAD_URL}/pytorch-doc-2.5/091cd2dd06ff9c391e040564724034d2.png

计算此分布式自动微分图的依赖关系更为复杂,需要一定的开销(无论是计算成本还是网络通信成本)。

对于性能敏感的应用程序,我们可以通过假设每个 sendrecv 函数在反向传递过程中都是有效的(大多数应用程序不会执行未使用的 RPC 操作)来避免很多开销。这简化了分布式自动微分算法,并且提高了效率,但代价是应用程序需要了解这些限制。该算法称为FAST模式算法,详情如下。

通常,并非每个 sendrecv 函数都适合作为反向传递的一部分。为解决此问题,我们提出了一种SMART 模式算法,该算法将在后续章节中详细介绍。目前仅实现了FAST模式算法。

FAST模式算法

该算法的核心假设是在进行反向传播时,每个 send 函数的依赖项为 1。换句话说,我们假定会通过远程过程调用(RPC)从另一个节点接收梯度。

算法如下:

  1. 我们从具有反向传播起点的工作者开始(所有的起点都必须是本地的)。

  2. 查找当前分布式自动微分上下文中的所有send函数。

  3. 从提供的根节点开始,计算所有检索到的send函数之间的依赖关系。

  4. 计算依赖关系后,使用提供的根节点启动本地自动微分引擎。

  5. 当自动微分引擎执行recv函数时,该函数通过RPC将输入的梯度发送给适当的工作者。每个recv函数都知道目标工作者ID,因为这个ID是在前向传递过程中记录下来的。recv函数还会将autograd_context_idautograd_message_id发送到远程主机。

  6. 当远程主机接收到此请求时,我们使用 autograd_context_idautograd_message_id 来查找合适的 send 函数。

  7. 如果这是第一次为给定的 autograd_context_id 分配任务,它将根据上述第 1-3 点所述进行本地依赖关系的计算。

  8. 在步骤 6 中检索到的 send 函数将会被加入队列,并在该工作线程的本地自动微分引擎上执行。

  9. 最后,我们不在 Tensor 的 .grad 字段上累积梯度,而是在每个 分布式自动微分上下文 中单独累积梯度。这些梯度存储在一个 Dict[Tensor, Tensor] 中,这是一个从 Tensor 到其相关梯度的映射,并且可以使用 get_gradients() API 获取该映射。





以下是一个带有分布式自动微分的完整代码示例:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

  # Run the backward pass.
  dist_autograd.backward(context_id, [loss])

  # Retrieve the gradients from the context.
  dist_autograd.get_gradients(context_id)

带依赖关系的分布式自动求导图如下所示(为了简化,省略了 t5.sum()):

{BASE_RAW_UPLOAD_URL}/pytorch-doc-2.5/83cbd093cb0536367fe00a07c5daa970.png

以下是在上述示例中应用的FAST 模式算法

  1. Worker 0上,我们从根节点losssend1开始计算依赖关系。结果是send1被标记为有一个依赖项,并且在Worker 0上的mul也被标记为有一个依赖项。

  2. 现在,我们在Worker 0启动本地自动微分引擎。首先执行mul函数,并将其输出累积在自动微分上下文中作为t4的梯度。然后,我们执行recv2,将梯度发送到Worker 1

  3. 因为这是Worker 1第一次听到关于这次反向传递的信息,它开始进行依赖性计算,并适当标记了send2addrecv1的依赖关系。

  4. 接下来,我们将send2加入到Worker 1的本地自动微分引擎中,这将执行addrecv1

  5. 当执行recv1时,它会将梯度发送给Worker 0

  6. 因为 Worker 0 已经为这次反向传播计算了依赖关系,所以它只需要在本地执行 send1

  7. 最后,t1t2t4 的梯度将在“分布式自动微分上下文”中累积。

SMART 模式算法

该算法的详细信息尚未完成,但你可以参考RFC 中的分布式自动微分算法智能模式部分来了解其基本思路。

分布式优化器

DistributedOptimizer 按照如下方式运行:

  1. 接收一个包含远程参数(RRef)的列表进行优化。这些参数也可以是被封装在本地 RRef 中的本地参数。

  2. 使用一个Optimizer 类作为本地优化器,在所有不同的 RRef 所有者上进行运行。

  3. 分布式优化器在每个工作节点上创建一个本地 Optimizer 实例,并保持对其的 RRef 引用。

  4. 当调用torch.distributed.optim.DistributedOptimizer.step()时,分布式优化器会使用 RPC 在适当的远程工作者上执行所有本地优化器。此外,在调用该方法时必须提供一个分布式自动微分的context_id,以便本地优化器能够应用存储在相应上下文中的梯度。

  5. 如果多个并发的分布式优化器在同一工作节点上更新相同的参数,这些更新将通过锁机制依次执行。

简单端到端示例

将所有内容结合在一起,以下是一个使用分布式自动微分和分布式优化器的简单端到端示例。如果将代码保存在名为“dist_autograd_simple.py”的文件中,并通过命令MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py运行。

import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
  # Run world_size workers
  world_size = 2
  mp.spawn(run_process, args=(world_size,), nprocs=world_size)
本页目录