全分片数据并行 (FSDP) 入门指南
作者: Hamid Shojanazeri, Yanli Zhao, Shen Li
大规模训练 AI 模型是一项具有挑战性的任务,需要大量的计算能力和资源。同时,处理这些超大型模型的训练也带来了相当大的工程复杂性。PyTorch FSDP 在 PyTorch 1.11 中发布,使得这一过程更加简便。
在本教程中,我们将展示如何使用 FSDP API 来训练简单的 MNIST 模型,这些方法可以扩展到其他更大的模型,例如 HuggingFace BERT 模型 和 参数高达 1T 的 GPT-3 模型。示例的 DDP MNIST 代码由 Patrick Hu 提供。
FSDP 的工作原理
在 DistributedDataParallel (DDP) 训练中,每个进程/工作节点都拥有模型的一个副本,并处理一个批次的数据,最终使用全归约(all-reduce)来汇总不同工作节点上的梯度。在 DDP 中,模型权重和优化器状态会在所有工作节点之间复制。FSDP(Fully Sharded Data Parallel)是一种数据并行方式,它将模型参数、优化器状态和梯度分片到不同的 DDP 节点上。
使用 FSDP 进行训练时,所有工作节点上的 GPU 内存占用比使用 DDP 时更小。这使得某些非常大的模型的训练变得可行,因为它允许更大的模型或批次大小适应设备。不过,这也带来了通信量增加的成本。通过内部优化(如通信与计算的重叠),通信开销得到了减少。
在高层次上,FSDP 的工作原理如下:
在构造函数中
- 分片模型参数,每个 rank 仅保留其自己的分片
在前向传播过程中
-
运行
all_gather
以从所有进程收集所有分片,恢复此 FSDP 单元中的完整参数 -
执行前向计算
-
丢弃刚刚收集的参数分片
在反向传播过程中
-
运行
all_gather
以从所有节点收集所有分片,以恢复此 FSDP 单元中的完整参数 -
运行反向计算
-
运行
reduce_scatter
以同步梯度 -
丢弃参数。
一种理解FSDP分片的方式是将DDP梯度全归约操作分解为规约分散(reduce-scatter)和全收集(all-gather)。具体来说,在反向传播过程中,FSDP对梯度进行规约和分散,确保每个计算节点拥有一部分梯度。然后在优化器步骤中更新相应的参数分片。最后,在随后的前向传播中,执行全收集操作以汇总和合并更新后的参数分片。
如何使用 FSDP
在这里,我们使用一个简单的模型在 MNIST 数据集上进行训练以作演示。相关 API 和逻辑同样适用于训练更大的模型。
设置
1.1 安装 PyTorch 和 Torchvision
有关安装信息,请参阅 入门指南。
我们将以下代码片段添加到 Python 脚本 “FSDP_mnist.py” 中。
1.2 导入必要的包
本教程适用于 PyTorch 1.12 及更高版本。如果您使用的是较早的版本,请将所有
size_based_auto_wrap_policy
替换为default_auto_wrap_policy
,并将fsdp_auto_wrap_policy
替换为auto_wrap_policy
。
# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
importos
importargparse
importfunctools
importtorch
importtorch.nnasnn
importtorch.nn.functionalasF
importtorch.optimasoptim
fromtorchvisionimport datasets, transforms
fromtorch.optim.lr_schedulerimport StepLR
importtorch.distributedasdist
importtorch.multiprocessingasmp
fromtorch.nn.parallelimport DistributedDataParallel as DDP
fromtorch.utils.data.distributedimport DistributedSampler
fromtorch.distributed.fsdpimport FullyShardedDataParallel as FSDP
fromtorch.distributed.fsdp.fully_sharded_data_parallelimport (
CPUOffload,
BackwardPrefetch,
)
fromtorch.distributed.fsdp.wrapimport (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
1.3 分布式训练设置。正如我们提到的,FSDP 是一种数据并行方式,需要分布式训练环境,因此这里我们使用两个辅助函数来初始化分布式训练进程并进行清理。
defsetup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
defcleanup():
dist.destroy_process_group()
2.1 定义我们的手写数字分类玩具模型。
classNet(nn.Module):
def__init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
defforward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
2.2 定义一个训练函数
deftrain(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
ddp_loss = torch.zeros(2).to(rank)
if sampler:
sampler.set_epoch(epoch)
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target, reduction='sum')
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print('Train Epoch: {}\tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
2.3 定义验证函数
deftest(model, rank, world_size, test_loader):
model.eval()
correct = 0
ddp_loss = torch.zeros(3).to(rank)
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
ddp_loss[2] += len(data)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
test_loss = ddp_loss[0] / ddp_loss[2]
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
100. * ddp_loss[1] / ddp_loss[2]))
2.4 定义一个分布式训练函数,将模型包装在 FSDP 中
注意:要保存 FSDP 模型,我们需要在每个 rank 上调用 state_dict,然后在 Rank 0 上保存整体状态。
deffsdp_main(rank, world_size, args):
setup(rank, world_size)
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=100
)
torch.cuda.set_device(rank)
init_start_event = torch.cuda.Event(enable_timing=True)
init_end_event = torch.cuda.Event(enable_timing=True)
model = Net().to(rank)
model = FSDP(model)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
init_start_event.record()
for epoch in range(1, args.epochs + 1):
train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
test(model, rank, world_size, test_loader)
scheduler.step()
init_end_event.record()
if rank == 0:
init_end_event.synchronize()
print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event)/1000}sec")
print(f"{model}")
if args.save_model:
# use a barrier to make sure training is done on all ranks
dist.barrier()
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")
cleanup()
2.5 最后,解析参数并设置主函数
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
WORLD_SIZE = torch.cuda.device_count()
mp.spawn(fsdp_main,
args=(WORLD_SIZE, args),
nprocs=WORLD_SIZE,
join=True)
我们记录了 CUDA 事件以测量 FSDP 模型特定部分的时间。CUDA 事件时间为 110.85 秒。
pythonFSDP_mnist.py
CUDAeventelapsedtimeontrainingloop40.67462890625sec
使用 FSDP 包装模型后,模型将如下所示,我们可以看到模型已被包装在一个 FSDP 单元中。接下来,我们将探讨添加 auto_wrap_policy
的情况,并讨论其差异。
FullyShardedDataParallel(
(_fsdp_wrapped_module):FlattenParamsWrapper(
(_fpw_module):Net(
(conv1):Conv2d(1,32,kernel_size=(3,3),stride=(1,1))
(conv2):Conv2d(32,64,kernel_size=(3,3),stride=(1,1))
(dropout1):Dropout(p=0.25,inplace=False)
(dropout2):Dropout(p=0.5,inplace=False)
(fc1):Linear(in_features=9216,out_features=128,bias=True)
(fc2):Linear(in_features=128,out_features=10,bias=True)
)
)
)
以下是使用 PyTorch Profiler 在配备 4 个 GPU 的 g4dn.12.xlarge AWS EC2 实例上,进行 FSDP MNIST 训练时的峰值内存使用情况。
在 FSDP 中应用 auto_wrap_policy,否则 FSDP 会将整个模型放在一个 FSDP 单元中,这会降低计算效率和内存效率。其工作原理是,假设您的模型包含 100 个线性层。如果您执行 FSDP(model)
,则只会有一个 FSDP 单元来包装整个模型。在这种情况下,allgather
将收集所有 100 个线性层的完整参数,因此不会通过参数分片来节省 CUDA 内存。此外,对于所有 100 个线性层,只有一个阻塞的 allgather
调用,层之间不会有通信和计算的重叠。
为了避免这种情况,您可以传入一个 auto_wrap_policy
,它会在满足指定条件(例如大小限制)时自动密封当前的 FSDP 单元并启动一个新的 FSDP 单元。这样,您将拥有多个 FSDP 单元,并且一次只需要一个 FSDP 单元收集完整的参数。例如,假设您有 5 个 FSDP 单元,每个单元封装了 20 个线性层。那么,在前向传播过程中,第一个 FSDP 单元会为前 20 个线性层收集参数,进行计算,随后丢弃这些参数,接着处理接下来的 20 个线性层。因此,在任何时刻,每个 rank 只会实例化 20 个线性层的参数/梯度,而不是 100 个。
在 2.4 版本中,我们定义了 auto_wrap_policy
并将其传递给 FSDP 包装器。在以下示例中,my_auto_wrap_policy
定义了一个层如果其参数数量大于 100,则可以被 FSDP 包装或分片。如果该层的参数数量小于 100,它将会与其他小层一起被 FSDP 包装。找到一个最优的自动包装策略是具有挑战性的,PyTorch 将在未来为此配置添加自动调优功能。在没有自动调优工具的情况下,最好通过实验性地使用不同的自动包装策略来分析您的工作流程,并找到最优的配置。
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=20000
)
torch.cuda.set_device(rank)
model = Net().to(rank)
model = FSDP(model,
auto_wrap_policy=my_auto_wrap_policy)
应用 auto_wrap_policy 后,模型将如下所示:
FullyShardedDataParallel(
(_fsdp_wrapped_module):FlattenParamsWrapper(
(_fpw_module):Net(
(conv1):Conv2d(1,32,kernel_size=(3,3),stride=(1,1))
(conv2):Conv2d(32,64,kernel_size=(3,3),stride=(1,1))
(dropout1):Dropout(p=0.25,inplace=False)
(dropout2):Dropout(p=0.5,inplace=False)
(fc1):FullyShardedDataParallel(
(_fsdp_wrapped_module):FlattenParamsWrapper(
(_fpw_module):Linear(in_features=9216,out_features=128,bias=True)
)
)
(fc2):Linear(in_features=128,out_features=10,bias=True)
)
)
pythonFSDP_mnist.py
CUDAeventelapsedtimeontrainingloop41.89130859375sec
以下是从 PyTorch Profiler 捕获的在使用 auto_wrap 策略的 FSDP 上,在 4 个 GPU 的 g4dn.12.xlarge AWS EC2 实例上进行 MNIST 训练的峰值内存使用情况。可以观察到,与未应用 auto_wrap 策略的 FSDP 相比,每个设备的峰值内存使用量更小,从约 75 MB 减少到 66 MB。
CPU 卸载:如果模型非常大,即使使用 FSDP 也无法完全放入 GPU 中,那么 CPU 卸载在这里可能会有所帮助。
目前,仅支持参数和梯度的 CPU 卸载。可以通过传入 cpu_offload=CPUOffload(offload_params=True)
来启用该功能。
需要注意的是,当前该功能会隐式启用梯度卸载到 CPU,以便参数和梯度位于同一设备上,从而与优化器一起工作。此 API 可能会发生变化。默认值为 None
,在这种情况下不会进行任何卸载。
由于频繁将张量从主机复制到设备,使用此功能可能会显著减慢训练速度,但它有助于提高内存效率并训练更大规模的模型。
在 2.4 版本中,我们将其添加到 FSDP 封装器中。
model = FSDP(model,
auto_wrap_policy=my_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True))
将其与 DDP 进行比较,如果在 2.4 中我们只是正常地将模型封装在 DDP 中,并在“DDP_mnist.py”中保存更改。
model = Net().to(rank)
model = DDP(model)
pythonDDP_mnist.py
CUDAeventelapsedtimeontrainingloop39.77766015625sec
以下是从 PyTorch profiler 中捕获的,在具有 4 个 GPU 的 g4dn.12.xlarge AWS EC2 实例上进行的 DDP MNIST 训练的内存峰值使用情况。
考虑到我们在这里定义的玩具示例和小型 MNIST 模型,我们可以观察到 DDP 和 FSDP 的峰值内存使用情况之间的差异。在 DDP 中,每个进程都持有模型的一个副本,因此与 FSDP 相比,内存占用更高,FSDP 会在 DDP 的各个节点之间分片模型参数、优化器状态和梯度。使用 auto_wrap 策略的 FSDP 的峰值内存使用量是最低的,其次是 FSDP 和 DDP。
此外,从时间上看,考虑到小型模型并在单台机器上运行训练,带或不带 auto_wrap 策略的 FSDP 的性能几乎与 DDP 一样快。这个示例并不代表大多数实际应用场景,有关 DDP 和 FSDP 的详细分析和比较,请参考这篇 博客文章。