训练脚本

如果你的训练脚本与torch.distributed.launch一起工作,它也将与torchrun一起工作,并且会有一些差异:

  1. 无需手动传递 RANKWORLD_SIZEMASTER_ADDRMASTER_PORT 这些参数。

  2. 可以提供 rdzv_backendrdzv_endpoint。对于大多数用户来说,这通常会被设置为 c10d(参见会合点)。默认情况下,rdzv_backend 创建一个非弹性的会合点,在这种情况下,rdzv_endpoint 保存主地址。

  3. 确保你的脚本中包含 load_checkpoint(path)save_checkpoint(path) 逻辑。当任何数量的工作进程失败时,我们会使用相同的程序参数重启所有工作进程,因此你会丢失到最近检查点之前的进度(参见弹性启动)。

  4. use_env 标志已被移除。如果你之前通过解析 --local-rank 选项来获取本地 rank,现在需要从环境变量 LOCAL_RANK 中获取(例如:int(os.environ["LOCAL_RANK"]))。

以下是一个示例训练脚本,在每个 epoch 都会创建检查点。因此,如果发生故障,最多只会丢失一个完整 epoch 的训练进度。

def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)

     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)

     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)

查看我们关于torchelastic兼容训练脚本的具体示例,请访问我们的示例页面。

本页目录