分布式优化器
警告
当前使用 CUDA 张量时不支持分布式优化器
torch.distributed.optim
提供了 DistributedOptimizer,它接受一个远程参数列表(RRef
),并在参数所在的工作者节点上本地运行优化器。分布式优化器可以使用任何本地优化器的基类在每个工作者上应用梯度。
- classtorch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[源代码]
-
DistributedOptimizer 获取分布在各工作节点上的参数的远程引用,并为每个参数本地应用指定的优化器。
此类使用
get_gradients()
来检索特定参数的梯度。来自同一客户端或不同客户端对
step()
的并发调用将在每个工作进程中进行序列化——因为每个工作进程的优化器一次只能处理一组梯度。然而,并不能保证整个前向传播、后向传播和优化过程会按客户端顺序执行。这意味着正在应用的梯度可能不对应于在给定工作进程中执行的最新一轮前向传递。此外,没有跨工作进程的确定顺序。DistributedOptimizer 默认使用 TorchScript 创建本地优化器,因此在多线程训练(如分布式模型并行)中,优化器更新不会被 Python 全局解释器锁 (GIL) 阻塞。此功能目前适用于大多数优化器。你可以参考PyTorch 实用技巧来为自己的自定义优化器启用 TorchScript 支持。
- 参数
-
-
optimizer_class (optim.Optimizer) – 指定在每个工作进程中实例化的优化器类。
-
params_rref (list[RRef]) – 需要优化的本地或远程参数的 RRefs 列表。
-
args - 在每个工作进程中传递给优化器构造函数的参数。
-
kwargs - 在每个工作进程中传递给优化器构造函数的参数。
-
- 示例:
-
>>> import torch.distributed.autograd as dist_autograd >>> import torch.distributed.rpc as rpc >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> >>> with dist_autograd.context() as context_id: >>> # Forward pass. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) >>> loss = rref1.to_here() + rref2.to_here() >>> >>> # Backward pass. >>> dist_autograd.backward(context_id, [loss.sum()]) >>> >>> # Optimizer. >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> [rref1, rref2], >>> lr=0.05, >>> ) >>> dist_optim.step(context_id)
- step(context_id)[源代码]
-
执行一次优化步骤。
这将调用
torch.optim.Optimizer.step()
,在每个包含需要优化参数的工作者上执行,并且会阻塞直到所有工作者返回。提供的context_id
将用于检索相应的context
,其中包含了应应用于参数的梯度。- 参数
-
context_id - 自动微分的上下文ID,用于执行优化器步骤。
- classtorch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[源代码]
-
封装任意的
torch.optim.Optimizer
并运行后局部SGD。此优化器在每一步执行本地优化器操作。经过预热阶段之后,它会在应用本地优化器之后周期性地对参数进行平均。- 参数
-
-
optim (Optimizer) – 本地的优化器。
-
averager (ModelAverager) — 用于运行本地 SGD 算法之后的模型平均实例。
-
示例:
>>> import torch >>> import torch.distributed as dist >>> import torch.distributed.algorithms.model_averaging.averagers as averagers >>> import torch.nn as nn >>> from torch.distributed.optim import PostLocalSGDOptimizer >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( >>> PostLocalSGDState, >>> post_localSGD_hook, >>> ) >>> >>> model = nn.parallel.DistributedDataParallel( >>> module, device_ids=[rank], output_device=rank >>> ) >>> >>> # Register a post-localSGD communication hook. >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) >>> model.register_comm_hook(state, post_localSGD_hook) >>> >>> # Create a post-localSGD optimizer that wraps a local optimizer. >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) >>> opt = PostLocalSGDOptimizer( >>> optim=local_optim, >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) >>> ) >>> >>> # In the first 100 steps, DDP runs global gradient averaging at every step. >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. >>> for step in range(0, 200): >>> opt.zero_grad() >>> loss = loss_fn(output, labels) >>> loss.backward() >>> opt.step()
- load_state_dict(state_dict)[源代码]
-
这与
torch.optim.Optimizer
中的load_state_dict()
方法相同,但还会将模型平均器的步数恢复为state_dict
中保存的值。如果没有
"step"
条目在state_dict
中,它会发出警告,并将模型平均器的步数初始化为 0。
- state_dict()[源代码]
-
这与
torch.optim.Optimizer
的state_dict()
相同,但会添加一个额外的条目来记录模型平均器的步数到检查点中,以确保重新加载时不会再次导致不必要的预热。
- step()[源代码]
-
执行一次优化步骤(参数更新)。
- classtorch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[源代码]
-
将任意的
optim.Optimizer
进行包装,并将其状态分布在组内的不同排名上。共享操作按 ZeRO 中所述方式进行。
在每个 rank 中,本地优化器实例仅负责更新大约
1 / world_size
的参数,并且只需要维护相应的1 / world_size
优化器状态。当参数被本地更新后,每个 rank 将其参数广播给所有其他节点以保持所有模型副本同步。ZeroRedundancyOptimizer
可与torch.nn.parallel.DistributedDataParallel
结合使用,从而减少每个 rank 的峰值内存消耗。ZeroRedundancyOptimizer
使用排序贪婪算法,在每个级别上打包一定数量的参数。每个参数只属于一个级别,不会跨级别分配。分区是任意的,可能与参数的注册和使用顺序不一致。- 参数
-
params (
Iterable
) – 参数的集合,包含torch.Tensor
或者dict
类型的对象,这些参数将在不同的 ranks 之间进行分片。 - 关键字参数
-
-
optimizer_class (
torch.nn.Optimizer
) - 本地优化器的类。 -
process_group (
ProcessGroup
, 可选) –torch.distributed
中的ProcessGroup
(默认值为通过torch.distributed.init_process_group()
初始化的dist.group.WORLD
)。 -
parameters_as_bucket_view (bool, optional) – 如果为
True
,参数会被打包到桶中以加快通信速度,并且每个param.data
字段会指向不同偏移量的桶视图;如果为False
,则每个单独的参数会分别进行通信,并且每个params.data
保持不变(默认值:False
)。 -
overlap_with_ddp (bool, 可选) – 如果为
True
,则step()
与DistributedDataParallel
的梯度同步重叠。这需要(1)提供一个功能性优化器或具有功能等效性的优化器作为optimizer_class
参数,以及(2)从ddp_zero_hook.py
中的函数构造并注册一个DDP通信钩子;参数会被打包到与DistributedDataParallel
匹配的桶中,这意味着parameters_as_bucket_view
参数将被忽略。如果为False
,则step()
在反向传播之后独立运行(按常规)。(默认值:False
) -
defaults - 任何尾随的参数将会被传递给本地优化器。
-
示例:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer
要求所有传入的参数必须是同一类型的密集型参数。警告
如果你设置了
overlap_with_ddp=True
,请注意以下事项:当前实现中,重叠使用DistributedDataParallel
和ZeroRedundancyOptimizer
时,在优化器步骤中的前两到三次训练迭代不会执行参数更新。具体来说,如果static_graph=False
,则在第二次前向传递后确定梯度桶策略;如果是static_graph=True
,则在第三次前向传递后确定。为了应对这种情况,可以在输入数据前面添加虚拟数据。警告
ZeroRedundancyOptimizer 是试验性的,可能随时发生变化。
- add_param_group(param_group)[源代码]
-
在
Optimizer
的param_groups
中添加一个参数组。在微调预训练网络时,这一点非常有用,因为可以将冻结的层设为可训练状态,并随着训练过程将其添加到
Optimizer
中。- 参数
-
param_group (dict) – 指定需要优化的参数以及每个组特有的优化选项。
警告
此方法负责更新所有分区上的分片,但必须在所有排名上进行调用。如果只在部分排名上调用该方法,则会导致训练停滞,因为通信原语的调用依赖于管理的参数,并且期望所有排名都参与相同的参数集。
- consolidate_state_dict(to=0)[源代码]
-
在目标 rank 上合并一个
state_dict
列表(每个 rank 有一个)。- 参数
-
to (int) - 指定接收优化器状态的等级(默认为 0)。
- 异常
-
RuntimeError – 如果设置了
overlap_with_ddp=True
并且在调用此ZeroRedundancyOptimizer
实例完成初始化之前就调用了该方法,而初始化发生在DistributedDataParallel
梯度桶被重建之后。
警告
这需要在所有进程中调用。
- 属性join_device:设备
-
返回默认设备。
- join_hook(**kwargs)[源代码]
-
返回 ZeRO 连接钩子。
它通过在优化器步骤中模拟集体通信,使不规则输入的训练成为可能。
在调用此钩子之前,必须正确设置梯度。
- 参数
-
kwargs (dict) – 一个包含关键字参数的字典,用于在运行时修改连接钩子的行为;所有共享相同连接上下文管理器的
Joinable
实例都会收到相同的kwargs
值。
此钩子不支持任何关键字参数;即
kwargs
未被使用。
- 属性join_process_group:Any
-
返回进程组。
- load_state_dict(state_dict)[源代码]
-
从输入的
state_dict
中加载给定排名的状态,并根据需要更新本地优化器。- 参数
-
state_dict (dict) – 优化器的状态;应为调用
state_dict()
返回的对象。 - 异常
-
RuntimeError – 如果设置了
overlap_with_ddp=True
并且在调用此ZeroRedundancyOptimizer
实例完成初始化之前就调用了该方法,而初始化发生在DistributedDataParallel
梯度桶被重建之后。
- state_dict()[源代码]
-
返回该排名所知道的最后一个全局优化器状态。
- 异常
-
RuntimeError – 如果
overlap_with_ddp=True
并且在调用此ZeroRedundancyOptimizer
实例完成初始化之前就调用了该方法,而初始化发生在DistributedDataParallel
梯度桶被重建之后;或者在没有先调用consolidate_state_dict()
的情况下调用了该方法。 - 返回类型