PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

分布式 RPC 框架入门指南

作者: Shen Li

先决条件:

本教程通过两个简单的示例,展示了如何使用 torch.distributed.rpc 包构建分布式训练,该包最初作为实验性功能在 PyTorch v1.4 中引入。这两个示例的源代码可以在 PyTorch 示例 中找到。

之前的教程,分布式数据并行入门使用 PyTorch 编写分布式应用程序,介绍了 DistributedDataParallel,它支持一种特定的训练范式,即模型在多个进程中复制,每个进程处理一部分输入数据。然而,有时您可能会遇到需要不同训练范式的情况。例如:

  1. 在强化学习中,从环境中获取训练数据可能相对昂贵,而模型本身可能非常小。在这种情况下,生成多个并行运行的观察者并共享一个代理可能会很有用。在这种情况下,代理负责本地训练,但应用程序仍然需要库来在观察者和训练者之间发送和接收数据。

  2. 您的模型可能太大,无法适应单台机器上的 GPU,因此需要一个库来帮助将模型拆分到多台机器上。或者您可能正在实现一个参数服务器训练框架,其中模型参数和训练者位于不同的机器上。

torch.distributed.rpc 包可以帮助处理上述场景。在第一种情况下,RPCRRef 允许将数据从一个工作节点发送到另一个工作节点,同时轻松引用远程数据对象。在第二种情况下,分布式自动求导分布式优化器 使得执行反向传播和优化器步骤就像本地训练一样。在接下来的两节中,我们将通过一个强化学习示例和一个语言模型示例来演示 torch.distributed.rpc 的 API。请注意,本教程的目的不是构建最准确或最有效的模型来解决给定的问题,而是展示如何使用 torch.distributed.rpc 包来构建分布式训练应用程序。

使用 RPC 和 RRef 的分布式强化学习

本节介绍如何使用 RPC 构建一个玩具分布式强化学习模型来解决 OpenAI Gym 中的 CartPole-v1 问题。策略代码主要借鉴了现有的单线程 示例,如下所示。我们将跳过 Policy 设计的细节,重点讨论 RPC 的使用。

importtorch.nnasnn
importtorch.nn.functionalasF

classPolicy(nn.Module):

    def__init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

    defforward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)

我们已经准备好展示观察者(observer)。在这个例子中,每个观察者都会创建自己的环境,并等待代理(agent)的命令来运行一个回合(episode)。在每个回合中,一个观察者最多循环 n_steps 次迭代,在每次迭代中,它使用 RPC 将其环境状态传递给代理,并获取一个动作。然后,它将这个动作应用到环境中,并从环境中获取奖励和下一个状态。之后,观察者使用另一个 RPC 将奖励报告给代理。再次请注意,这显然不是最高效的观察者实现。例如,一个简单的优化可以是将当前状态和上一次的奖励打包到一个 RPC 中,以减少通信开销。然而,我们的目标是展示 RPC API,而不是为 CartPole 构建最佳求解器。因此,在这个示例中,我们将保持逻辑简单,并将这两个步骤明确展示出来。

importargparse
importgym
importtorch.distributed.rpcasrpc

parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument('--world_size', default=2, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed  for reproducibility')
args = parser.parse_args()

classObserver:

    def__init__(self):
        self.id = rpc.get_worker_info().id
        self.env = gym.make('CartPole-v1')
        self.env.seed(args.seed)

    defrun_episode(self, agent_rref):
        state, ep_reward = self.env.reset(), 0
        for _ in range(10000):
            # send the state to the agent to get an action
            action = agent_rref.rpc_sync().select_action(self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purpose
            agent_rref.rpc_sync().report_reward(self.id, reward)

            # finishes after the number of self.env._max_episode_steps
            if done:
                break

代理的代码稍微复杂一些,我们将它分解为多个部分。在这个例子中,代理既充当训练器又充当主控器,它向多个分布式观察者发送命令以运行回合,同时还会在本地记录所有动作和奖励,这些数据将在每个回合之后的训练阶段使用。下面的代码展示了 Agent 的构造函数,其中大部分行都在初始化各种组件。最后的循环在其他工作节点上远程初始化观察者,并在本地持有这些观察者的 RRefs。代理稍后将使用这些观察者 RRefs 来发送命令。应用程序无需担心 RRefs 的生命周期。每个 RRef 的所有者都维护一个引用计数映射来跟踪其生命周期,并保证只要存在任何活跃的用户,远程数据对象就不会被删除。详情请参考 RRef 设计文档

importgym
importnumpyasnp

importtorch
importtorch.distributed.rpcasrpc
importtorch.optimasoptim
fromtorch.distributed.rpcimport RRef, rpc_async, remote
fromtorch.distributionsimport Categorical

classAgent:
    def__init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

接下来,代理向观察者暴露了两个 API,用于选择动作和报告奖励。这些函数仅在代理本地运行,但会通过 RPC 由观察者触发。

classAgent:
    ...
    defselect_action(self, ob_id, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    defreport_reward(self, ob_id, reward):
        self.rewards[ob_id].append(reward)

让我们在 agent 中添加一个 run_episode 函数,该函数通知所有观察者执行一个 episode。在这个函数中,首先创建一个列表来收集异步 RPC 的 futures,然后遍历所有观察者的 RRefs 来发起异步 RPC。在这些 RPC 中,agent 还会将自身的 RRef 传递给观察者,以便观察者也能调用 agent 上的函数。如上所示,每个观察者都会向 agent 发起 RPC,这些是嵌套的 RPC。在每个 episode 结束后,saved_log_probsrewards 将包含记录的动作概率和奖励。

classAgent:
    ...
    defrun_episode(self):
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    ob_rref.rpc_sync().run_episode,
                    args=(self.agent_rref,)
                )
            )

        # wait until all obervers have finished this episode
        for fut in futs:
            fut.wait()

最后,在一轮训练结束后,代理需要训练模型,这一过程在下面的 finish_episode 函数中实现。该函数中没有涉及 RPC 调用,并且大部分代码是从单线程示例中借鉴而来的。因此,我们略过对其内容的详细描述。

classAgent:
    ...
    deffinish_episode(self):
      # joins probs and rewards from different observers into lists
      R, probs, rewards = 0, [], []
      for ob_id in self.rewards:
          probs.extend(self.saved_log_probs[ob_id])
          rewards.extend(self.rewards[ob_id])

      # use the minimum observer reward to calculate the running reward
      min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
      self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

      # clear saved probs and rewards
      for ob_id in self.rewards:
          self.rewards[ob_id] = []
          self.saved_log_probs[ob_id] = []

      policy_loss, returns = [], []
      for r in rewards[::-1]:
          R = r + args.gamma * R
          returns.insert(0, R)
      returns = torch.tensor(returns)
      returns = (returns - returns.mean()) / (returns.std() + self.eps)
      for log_prob, R in zip(probs, returns):
          policy_loss.append(-log_prob * R)
      self.optimizer.zero_grad()
      policy_loss = torch.cat(policy_loss).sum()
      policy_loss.backward()
      self.optimizer.step()
      return min_reward

通过 PolicyObserverAgent 类,我们已准备好启动多个进程来执行分布式训练。在这个示例中,所有进程都运行相同的 run_worker 函数,并通过 rank 来区分它们的角色。Rank 0 始终是 agent,而其他所有 rank 都是 observers。agent 作为主节点,通过反复调用 run_episodefinish_episode 来工作,直到运行奖励超过环境指定的奖励阈值。所有 observers 则被动等待来自 agent 的命令。代码被 rpc.init_rpcrpc.shutdown 包裹,这两个函数分别用于初始化和终止 RPC 实例。更多细节可以在 API 页面 上找到。

importos
fromitertoolsimport count

importtorch.multiprocessingasmp

AGENT_NAME = "agent"
OBSERVER_NAME="obs{}"

defrun_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # rank0 is the agent
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

        agent = Agent(world_size)
        print(f"This will run until reward threshold of {agent.reward_threshold}"
                " is reached. Ctrl+C to exit.")
        for i_episode in count(1):
            agent.run_episode()
            last_reward = agent.finish_episode()

            if i_episode % args.log_interval == 0:
                print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
                    f"{agent.running_reward:.2f}")
            if agent.running_reward > agent.reward_threshold:
                print(f"Solved! Running reward is now {agent.running_reward}!")
                break
    else:
        # other ranks are the observer
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
        # observers passively waiting for instructions from the agent

    # block until all rpcs finish, and shutdown the RPC instance
    rpc.shutdown()


mp.spawn(
    run_worker,
    args=(args.world_size, ),
    nprocs=args.world_size,
    join=True
)

以下是在 world_size=2 情况下训练时的一些示例输出。

This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
Episode 10      Last reward: 26.00      Average reward: 10.01
Episode 20      Last reward: 16.00      Average reward: 11.27
Episode 30      Last reward: 49.00      Average reward: 18.62
Episode 40      Last reward: 45.00      Average reward: 26.09
Episode 50      Last reward: 44.00      Average reward: 30.03
Episode 60      Last reward: 111.00     Average reward: 42.23
Episode 70      Last reward: 131.00     Average reward: 70.11
Episode 80      Last reward: 87.00      Average reward: 76.51
Episode 90      Last reward: 86.00      Average reward: 95.93
Episode 100     Last reward: 13.00      Average reward: 123.93
Episode 110     Last reward: 33.00      Average reward: 91.39
Episode 120     Last reward: 73.00      Average reward: 76.38
Episode 130     Last reward: 137.00     Average reward: 88.08
Episode 140     Last reward: 89.00      Average reward: 104.96
Episode 150     Last reward: 97.00      Average reward: 98.74
Episode 160     Last reward: 150.00     Average reward: 100.87
Episode 170     Last reward: 126.00     Average reward: 104.38
Episode 180     Last reward: 500.00     Average reward: 213.74
Episode 190     Last reward: 322.00     Average reward: 300.22
Episode 200     Last reward: 165.00     Average reward: 272.71
Episode 210     Last reward: 168.00     Average reward: 233.11
Episode 220     Last reward: 184.00     Average reward: 195.02
Episode 230     Last reward: 284.00     Average reward: 208.32
Episode 240     Last reward: 395.00     Average reward: 247.37
Episode 250     Last reward: 500.00     Average reward: 335.42
Episode 260     Last reward: 500.00     Average reward: 386.30
Episode 270     Last reward: 500.00     Average reward: 405.29
Episode 280     Last reward: 500.00     Average reward: 443.29
Episode 290     Last reward: 500.00     Average reward: 464.65
Solved! Running reward is now 475.3163778435275!

在本示例中,我们展示了如何使用 RPC 作为通信载体在多个工作节点之间传递数据,以及如何使用 RRef 来引用远程对象。尽管您可以直接在 ProcessGroupsendrecv API 之上构建整个结构,或者使用其他通信/RPC 库,但通过使用 torch.distributed.rpc,您可以获得原生支持以及底层的持续性能优化。

接下来,我们将展示如何将 RPC 和 RRef 与分布式自动求导和分布式优化器结合,以执行分布式模型并行训练。

使用分布式自动求导和分布式优化器的分布式 RNN

在本节中,我们使用一个RNN模型来展示如何使用RPC API构建分布式模型并行训练。这个RNN模型的示例非常小,可以轻松放入单个GPU中,但我们仍然将其层划分到两个不同的工作节点上,以展示这一思路。开发者可以应用类似的技术,将更大的模型分布到多个设备和机器上。

该RNN模型的设计借鉴了PyTorch 示例仓库中的词语言模型,它包含三个主要组件:一个嵌入表、一个LSTM层和一个解码器。下面的代码将嵌入表和解码器封装到子模块中,以便它们的构造函数可以传递给RPC API。在EmbeddingTable子模块中,我们特意将Embedding层放在GPU上,以覆盖这一使用场景。在v1.4版本中,RPC始终在目标工作节点上创建CPU张量参数或返回值。如果函数接受GPU张量,则需要显式地将其移动到合适的设备上。

classEmbeddingTable(nn.Module):
r"""
    Encoding layers of the RNNModel
    """
    def__init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    defforward(self, input):
        return self.drop(self.encoder(input.cuda()).cpu()


classDecoder(nn.Module):
    def__init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    defforward(self, output):
        return self.decoder(self.drop(output))

有了上述子模块,我们现在可以使用 RPC 将它们组合在一起,创建 RNN 模型。在下面的代码中,ps 表示参数服务器,它托管嵌入表和解码器的参数。构造函数使用 remote API 在参数服务器上创建 EmbeddingTable 对象和 Decoder 对象,并在本地创建 LSTM 子模块。在前向传播过程中,训练器使用 EmbeddingTableRRef 查找远程子模块,并通过 RPC 将输入数据传递给 EmbeddingTable 并获取查找结果。然后,它将嵌入结果通过本地的 LSTM 层进行处理,最后使用另一个 RPC 将输出发送到 Decoder 子模块。总的来说,为了实现分布式模型并行训练,开发者可以将模型划分为多个子模块,调用 RPC 远程创建子模块实例,并在需要时使用 RRef 查找它们。正如你在下面的代码中看到的,它与单机模型并行训练非常相似。主要区别在于用 RPC 函数替代了 Tensor.to(device)

classRNNModel(nn.Module):
    def__init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    defforward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the rremote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

在介绍分布式优化器之前,让我们先添加一个辅助函数来生成模型参数的 RRef 列表,分布式优化器将使用这些 RRef。在本地训练中,应用程序可以调用 Module.parameters() 来获取所有参数张量的引用,并将其传递给本地优化器以进行后续更新。然而,相同的 API 在分布式训练场景中并不适用,因为一些参数存在于远程机器上。因此,分布式优化器不是接收参数 Tensors 的列表,而是接收 RRefs 的列表,每个模型参数都有一个 RRef,无论是本地还是远程模型参数。这个辅助函数非常简单,只需调用 Module.parameters() 并在每个参数上创建一个本地 RRef

def_parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

由于 RNNModel 包含三个子模块,我们需要调用三次 _parameter_rrefs,并将其封装到另一个辅助函数中。

classRNNModel(nn.Module):
    ...
    defparameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params

现在,我们准备实现训练循环。在初始化模型参数后,我们创建 RNNModelDistributedOptimizer。分布式优化器将接收一系列参数 RRefs,找到所有不同的所有者工作节点,并在每个所有者工作节点上使用给定的参数(即 lr=0.05)创建指定的本地优化器(在本例中为 SGD,你也可以使用其他本地优化器)。

在训练循环中,首先创建一个分布式自动求导上下文,这将帮助分布式自动求导引擎找到梯度和涉及的 RPC 发送/接收函数。分布式自动求导引擎的设计细节可以在其设计文档中找到。然后,它像处理本地模型一样启动前向传播,并运行分布式反向传播。对于分布式反向传播,你只需要指定一个根节点列表,在本例中,它是损失 Tensor。分布式自动求导引擎将自动遍历分布式图并正确写入梯度。接下来,它在分布式优化器上运行 step 函数,这将触及所有涉及的本地优化器以更新模型参数。与本地训练相比,一个小的区别是你不需要运行 zero_grad(),因为每个自动求导上下文都有专门的空间来存储梯度,并且由于我们每次迭代都创建一个上下文,来自不同迭代的梯度不会累积到同一组 Tensors 中。

defrun_trainer():
    batch = 5
    ntoken = 10
    ninp = 2

    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    defget_next_batch():
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    # train for 10 iterations
    for epoch in range(10):
        for data, target in get_next_batch():
            # create distributed autograd context
            with dist_autograd.context() as context_id:
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # run distributed backward pass
                dist_autograd.backward(context_id, [loss])
                # run distributed optimizer
                opt.step(context_id)
                # not necessary to zero grads since they are
                # accumulated into the distributed autograd context
                # which is reset every iteration.
        print("Training epoch {}".format(epoch))

最后,我们添加一些粘合代码来启动参数服务器和训练器进程。

defrun_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server do nothing
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
本页目录