使用分布式 RPC 框架实现参数服务器
作者: Rohan Varma
前提条件:
本教程将通过一个简单的示例,介绍如何使用 PyTorch 的 分布式 RPC 框架 来实现一个参数服务器。参数服务器框架是一种范式,其中一组服务器存储参数(例如大型嵌入表),而多个训练器会查询参数服务器以获取最新的参数。这些训练器可以在本地运行训练循环,并偶尔与参数服务器同步以获取最新参数。要了解更多关于参数服务器方法的信息,请查看 这篇论文。
使用分布式 RPC 框架,我们将构建一个示例,其中多个训练器使用 RPC 与同一个参数服务器通信,并使用 RRef 访问远程参数服务器实例上的状态。每个训练器将通过跨多个节点的分布式 autograd 图拼接,以分布式方式启动其专用的反向传播。
注意:本教程涵盖了分布式 RPC 框架的使用,该框架适用于将模型拆分到多台机器上,或实现参数服务器训练策略,其中网络训练器从不同机器上获取参数。如果您希望在多个 GPU 上复制模型,请参阅 分布式数据并行教程。另外还有另一个 RPC 教程,涵盖了强化学习和 RNN 用例。
让我们从熟悉的内容开始:导入所需的模块并定义一个将在 MNIST 数据集上进行训练的简单卷积神经网络 (ConvNet)。下面的网络主要借鉴了 pytorch/examples 仓库 中定义的网络。
接下来,让我们定义一些辅助函数,这些函数将在脚本的其余部分中非常有用。以下使用 rpc_sync 和 RRef 来定义一个函数,该函数在远程节点上的对象上调用给定方法。在这里,我们对远程对象的句柄由 rref
参数提供,并在其所属节点上运行:rref.owner()
。在调用者节点上,我们通过使用 rpc_sync
同步运行此命令,这意味着我们将阻塞直到收到响应。
现在,我们已经准备好定义我们的参数服务器了。我们将继承 nn.Module
并保存一个指向上述定义的网络的句柄。我们还将保存一个输入设备,该设备将是在调用模型之前输入数据被传输到的设备。
接下来,我们将定义前向传播过程。请注意,无论模型输出的设备是什么,我们都会将输出移动到 CPU,因为 Distributed RPC Framework 目前仅支持通过 RPC 发送 CPU 张量。我们有意禁用了通过 RPC 发送 CUDA 张量,因为调用方/被调用方可能存在不同的设备(CPU/GPU),但在未来的版本中可能会支持这一功能。
接下来,我们将定义一些用于训练和验证的辅助函数。第一个函数 get_dist_gradients
将接收一个分布式自动求导上下文 ID,并调用 dist_autograd.get_gradients
API 以检索由分布式自动求导计算的梯度。更多信息可以在分布式自动求导文档中找到。需要注意的是,我们还会遍历结果字典,并将每个张量转换为 CPU 张量,因为框架目前仅支持通过 RPC 发送张量。接下来,get_param_rrefs
将遍历我们的模型参数,并将它们包装为(本地)RRef。该方法将由训练节点通过 RPC 调用,并返回需要优化的参数列表。这是分布式优化器的必需输入,因为它需要将所有需要优化的参数作为 RRef
列表提供。
最后,我们将创建一些方法来初始化我们的参数服务器。请注意,在所有进程中只会有一个参数服务器的实例,所有训练器将与同一个参数服务器通信,并更新相同的存储模型。如 run_parameter_server
中所示,服务器本身不会执行任何独立操作;它会等待来自训练器(尚未定义)的请求,并通过执行请求的函数来响应这些请求。
请注意,上面的 rpc.shutdown()
并不会立即关闭参数服务器。相反,它会等待所有工作进程(在本例中为训练器)也调用 rpc.shutdown()
。这确保了在所有训练器(尚未定义)完成其训练过程之前,参数服务器不会下线。
接下来,我们将定义 TrainerNet
类。它也将是 nn.Module
的子类,并且我们的 __init__
方法将使用 rpc.remote
API 来获取一个 RRef(远程引用),指向我们的参数服务器。请注意,这里我们并没有将参数服务器复制到本地进程,相反,我们可以将 self.param_server_rref
视为一个分布式共享指针,指向位于另一个进程中的参数服务器。
接下来,我们将定义一个名为 get_global_param_rrefs
的方法。为了理解为什么需要这个方法,值得仔细阅读 DistributedOptimizer 的文档,特别是其 API 签名。优化器必须传入一个 RRef
列表,这些 RRef
对应着需要优化的远程参数,因此我们在这里获取所需的 RRef
。由于给定的 TrainerNet
唯一交互的远程工作是 ParameterServer
,我们只需在 ParameterServer
上调用一个 remote_method
。我们使用了之前在 ParameterServer
类中定义的 get_param_rrefs
方法。该方法将返回一个 RRef
列表,这些 RRef
指向需要优化的参数。需要注意的是,在这个例子中,我们的 TrainerNet
并没有定义自己的参数;如果定义了,我们还需要将每个参数包装在一个 RRef
中,并将其包含在传递给 DistributedOptimizer
的输入中。
现在,我们准备好定义 forward
方法,该方法将调用(同步)RPC 来运行定义在 ParameterServer
上的网络的前向传递。请注意,我们将 self.param_server_rref
(这是我们 ParameterServer
的远程句柄)传递给我们的 RPC 调用。此调用将向运行 ParameterServer
的节点发送一个 RPC,调用前向传递,并返回与模型输出对应的 Tensor
。
在定义了我们的训练器之后,现在是时候编写我们的神经网络训练循环了。这个循环将创建我们的网络和优化器,通过网络运行一些输入并计算损失。训练循环看起来与本地训练程序非常相似,但由于我们的网络是分布式的,因此需要做一些修改。
在下面,我们初始化了 TrainerNet
并构建了一个 DistributedOptimizer
。需要注意的是,正如上面提到的,我们必须传入所有我们希望优化的全局参数(这些参数分布在参与分布式训练的所有节点上)。此外,我们传入了要使用的本地优化器,在这个例子中是 SGD。请注意,我们可以像创建本地优化器一样配置底层优化器算法——所有传递给 optimizer.SGD
的参数都会被正确转发。举个例子,我们传入了一个自定义的学习率,它将作为所有本地优化器的学习率。
接下来,我们定义我们的主训练循环。我们遍历由 PyTorch 的 DataLoader 提供的可迭代对象。在编写典型的前向/反向/优化器循环之前,我们首先将逻辑封装在 分布式自动求导上下文 中。请注意,这是为了记录在模型前向传递中调用的 RPC,以便构建一个包含所有参与分布式工作节点的反向传递的适当图。分布式自动求导上下文返回一个 context_id
,它作为标识符用于累积和优化对应于特定迭代的梯度。
与调用典型的 loss.backward()
不同(这会在本地工作节点上启动反向传播),我们调用 dist_autograd.backward()
并传入 context_id
以及 loss
,其中 loss
是我们希望反向传播开始的根节点。此外,我们将 context_id
传入优化器调用中,这是为了能够查找所有节点上由这次特定反向传播计算出的对应梯度所必需的。
以下代码简单地计算了我们在训练完成后模型的准确率,类似于传统的本地模型。然而,请注意,我们传入此函数的 net
是 TrainerNet
的一个实例,因此前向传播会以透明的方式调用 RPC。
接下来,类似于我们如何将 run_parameter_server
定义为 ParameterServer
的主循环,该循环负责初始化 RPC,让我们为训练器定义一个类似的循环。不同之处在于,我们的训练器必须运行我们上面定义的训练循环:
请注意,与 run_parameter_server
类似,rpc.shutdown()
默认会等待所有 worker(包括训练器和参数服务器)调用 rpc.shutdown()
之后,该节点才会退出。这确保了节点能够优雅地终止,且不会在另一个节点期望其在线时突然离线。
至此,我们已经完成了训练器和参数服务器的特定代码,剩下的就是添加启动训练器和参数服务器的代码。首先,我们必须接收适用于参数服务器和训练器的各种参数。world_size
对应参与训练的节点总数,是所有训练器和参数服务器的总和。我们还需要为每个进程传递一个唯一的 rank
,从 0(我们将在此运行单个参数服务器)到 world_size - 1
。master_addr
和 master_port
是用于标识 rank 0 进程运行位置的参数,各个节点将使用这些参数来相互发现。要在本地测试此示例,只需将 localhost
和相同的 master_port
传递给所有生成的实例。请注意,出于演示目的,此示例仅支持 0-2 个 GPU,尽管该模式可以扩展以利用更多 GPU。
现在,我们将根据命令行参数创建一个对应于参数服务器或训练器的进程。如果传入的rank为0,我们将创建一个ParameterServer
,否则创建一个TrainerNet
。请注意,我们使用torch.multiprocessing
启动一个子进程来执行我们想要运行的函数,并在主线程中使用p.join()
等待该进程完成。在初始化训练器时,我们还使用了PyTorch的dataloaders来指定MNIST数据集上的训练和测试数据加载器。
要在本地运行示例,请在单独的终端窗口中为服务器和每个要启动的工作节点运行以下命令:python rpc_parameter_server.py --world_size=WORLD_SIZE --rank=RANK
。例如,对于一个世界大小为2的主节点,命令应为 python rpc_parameter_server.py --world_size=2 --rank=0
。然后可以在另一个窗口中启动训练器,命令为 python rpc_parameter_server.py --world_size=2 --rank=1
,这将开始使用一个服务器和一个训练器进行训练。请注意,本教程假设训练使用0到2个GPU,可以通过将 --num_gpus=N
参数传递给训练脚本来配置此参数。
您可以通过命令行参数 --master_addr=ADDRESS
和 --master_port=PORT
来指定主工作节点监听的地址和端口。例如,在测试训练器和主节点运行在不同机器上的功能时可以使用这些参数。