PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

使用 Tensor Parallel (TP) 进行大规模 Transformer 模型训练

作者: Wanchao Liang, Tianyu Liu

本教程演示了如何使用 Tensor Parallel 和 Fully Sharded Data Parallel 在数百到数千个 GPU 上训练一个类似 Transformer 的大型模型。

前提条件:

Tensor Parallel 的工作原理

Tensor Parallel (TP) 最初在 Megatron-LM 论文中提出,它是一种高效的模型并行技术,用于训练大规模 Transformer 模型。我们在本教程中提到的 Sequence Parallel (SP) 是 Tensor Parallel 的一种变体,它在 nn.LayerNormRMSNorm 的序列维度上进行分片,以进一步节省训练期间的激活内存。随着模型规模的增大,激活内存成为瓶颈,因此在 Tensor Parallel 训练中,通常会将 Sequence Parallel 应用于 LayerNormRMSNorm 层。

Megatron-LM TP

图 1. 展示了在 Transformer 模型的 MLP 和 Self-Attention 层上的 Tensor Parallel 分片方式,其中注意力机制和 MLP 中的矩阵乘法通过分片计算进行(图片来源

在高层面上,PyTorch 张量并行的运作方式如下:

分片初始化

  • 确定要为每一层应用哪种 ParallelStyle,并通过调用 parallelize_module 来初始化模块并进行分片。

  • 并行化的模块将把其模型参数替换为 DTensor,而 DTensor 将负责使用分片计算来运行并行化的模块。

运行时正向/反向传播

  • 根据用户为每个 ParallelStyle 指定的输入/输出 DTensor 布局,它将运行适当的通信操作来转换输入/输出的 DTensor 布局(例如 allreduceallgatherreduce_scatter)。

  • 运行并行化层的分片计算以节省计算/内存(例如 nn.Linearnn.Embedding)。

何时以及为何应该使用 Tensor 并行

PyTorch 的完全分片数据并行(FSDP)已经具备将模型训练扩展到特定数量 GPU 的能力。然而,当需要在模型规模和 GPU 数量方面进一步扩展模型训练时,会出现许多额外的挑战,可能需要将张量并行与 FSDP 结合使用。

  1. 随着世界大小(GPU数量)变得异常庞大(超过128/256个GPU),FSDP的集合操作(如allgather)主要由环形延迟主导。通过在FSDP之上实现TP/SP,可以将FSDP世界大小减少8倍,使FSDP仅在主机间应用,从而将延迟成本降低相同的幅度。

  2. 当由于收敛性和GPU内存限制无法将全局批大小提高到超过GPU数量时,数据并行性已达到极限,张量/序列并行是唯一已知的可以“大致”增加全局批大小并继续扩展更多GPU的方法。这意味着模型大小和GPU数量都可以继续扩展。

  3. 对于某些类型的模型,当本地批大小变小时,TP/SP可以产生更优化浮点运算(FLOPS)的矩阵乘法形状。

那么,在预训练时,达到这些限制的难度如何?目前,预训练一个拥有数十亿或数万亿标记的大语言模型(LLM)可能需要数月时间,即使使用数千个GPU也是如此。

  • 在大规模训练LLM时,总会遇到第一个限制。例如,Llama 2 70B使用2000个GPU训练了35天,在2000个GPU的规模下需要多维并行技术。

  • 当Transformer模型变得更大时(如Llama2 70B),也会迅速遇到第二个限制。由于内存和收敛性的限制,即使本地batch_size=1也无法单独使用FSDP。例如,Llama 2的全局批大小为1K,因此在2000个GPU的规模下无法单独使用数据并行。

如何应用 Tensor Parallel

PyTorch Tensor Parallel API 提供了一组模块级别的原语(ParallelStyle),用于配置模型各个层的分片策略,包括:

  • ColwiseParallelRowwiseParallel: 以列或行的方式对 nn.Linearnn.Embedding 进行分片。

  • SequenceParallel: 对 nn.LayerNormnn.DropoutRMSNormPython 等进行分片计算。

  • PrepareModuleInputPrepareModuleOutput: 通过适当的通信操作配置模块输入/输出的分片布局。

为了演示如何使用 PyTorch 原生的 Tensor Parallel API,我们将以一个常见的 Transformer 模型为例。在本教程中,我们使用最新的 Llama2 模型 作为参考的 Transformer 模型实现,因为它在社区中也被广泛使用。

由于 Tensor Parallel 将单个张量分片到一组设备上,我们需要首先设置分布式环境(例如 NCCL 通信器)。Tensor Parallelism 是一种单程序多数据(SPMD)分片算法,类似于 PyTorch 的 DDP/FSDP,其底层利用了 PyTorch 的 DTensor 来执行分片。它还使用了 DeviceMesh 抽象(其底层管理着 ProcessGroups)来进行设备管理和分片。要了解如何利用 DeviceMesh 来设置多维并行,请参考 此教程。Tensor Parallel 通常在单个主机内工作,因此我们首先初始化一个连接主机内 8 个 GPU 的 DeviceMesh。

fromtorch.distributed.device_meshimport init_device_mesh

tp_mesh = init_device_mesh("cuda", (8,))

既然我们已经初始化了 DeviceMesh,接下来让我们详细了解一下 Llama 2 模型架构,并探讨如何进行 Tensor Parallel 的分片处理。我们将重点关注核心的 TransformerBlock,其中 Transformer 模型通过堆叠相同的 TransformerBlock 来扩展模型规模。

核心的 TransformerBlock 包含一个 Attention 层和一个 FeedForward 层。首先让我们来看一下较为简单的 FeedForward 层。FeedForward 层由三个线性层组成,它采用了 SwiGLU 风格的多层感知机(MLP),以下是其前向函数的实现:

# forward in the FeedForward layer
defforward(self, x):
    return self.w2(F.silu(self.w1(x)) * self.w3(x))

它同时执行 w1w3 的矩阵乘法,然后使用 w1/w3 线性投影结果的组合结果执行 w2 的矩阵乘法。这意味着我们可以借鉴 Tensor Parallelism 论文中的思想,将 w1/w3 的线性层以列方式分片,并将 w2 的线性层以行方式分片,这样在三个层的最后只会发生一次 allreduce 通信。使用 PyTorch 原生的 Tensor Parallel,我们可以简单地为 FeedForward 层创建一个 parallelize_plan,如下所示:

fromtorch.distributed.tensor.parallelimport ColwiseParallel, RowwiseParallel, parallelize_module

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "feed_foward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这就是我们如何使用 PyTorch Tensor Parallel API 来配置 FeedForward 层的分片。需要注意的是,用户只需指定如何对各个层进行分片,而通信(例如 allreduce)将在底层自动处理。

接下来是 Attention 层。它由 wqwkwv 线性层组成,用于将输入投影到 q/ k/ v,然后通过 wo 线性层执行注意力机制和输出投影。在这里,Tensor Parallelism 的目的是对 q/k/v 投影进行列分片,并对 wo 线性投影进行行分片。因此,我们可以将 Attention 计划添加到刚刚起草的 tp_plan 中:

layer_tp_plan = {
    # by default ColwiseParallel input layouts is replicated
    # and RowwiseParallel output layouts is replicated
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

这几乎就是我们需要在 TransformerBlock 上应用 Tensor Parallelism 的 layer_tp_plan。然而,有一点需要注意的是,当对线性层进行列向分片时,线性层的输出会在最后一个张量维度上被分片,而行向分片的线性层直接接受一个在最后一个维度上分片的输入。如果在列向线性层和行向线性层之间还有其他张量操作(如视图操作),我们需要将相关的形状操作调整为分片后的形状。

对于 Llama 模型,在注意力层中有一些与形状相关的视图操作。特别是对于 wq/wk/wv 线性层的列并行处理,激活张量在 num_heads 维度上进行了分片,因此我们需要将 num_heads 调整为局部的 num_heads

最后,我们需要调用 parallelize_module API 来使每个 TransformerBlock 的计划生效。在底层,它将 AttentionFeedForward 层中的模型参数分配到 DTensors 中,并根据需要为模型输入和输出(分别在每个模块之前和之后)注册通信钩子:

for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {...}  # i.e. the plan we just generated

    # Adjust attention module to use the local number of heads
    attn_layer = transformer_block.attention
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

既然我们已经详细说明了每个 TransformerBlock 的分片计划,通常在第一层会有一个 nn.Embedding,而在最后一层会有一个最终的 nn.Linear 投影层,用户可以选择对第一个 nn.Embedding 进行行分片或列分片,并对最后一个 nn.Linear 投影层进行列分片,同时指定适当的输入和输出布局。以下是一个示例:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
        ),
        "output": ColwiseParallel(
            output_layouts=Replicate(),
        ),
    }
)

如果待分区的模型过大,无法放入CPU内存,可以尝试使用meta设备初始化(例如,先在meta设备上初始化模型,对层进行分片,然后再实例化模型),或者在Transformer模型初始化期间逐层并行化TransformerBlock

将序列并行应用于 LayerNorm/RMSNorm

序列并行(Sequence Parallel)建立在上述张量并行(Tensor Parallel)的基础上。与基本的张量并行相比,后者仅在Attention模块和FeedForward模块内对张量进行分片,并保持这些模块的输入和输出(即前向传播中的激活值和反向传播中的梯度)为复制状态,而序列并行则在这些模块的序列维度上对它们进行分片。

在一个典型的TransformerBlock中,前向函数结合了归一化层(LayerNormRMSNorm)、注意力层、前馈层和残差连接。例如:

# forward in a TransformerBlock
defforward(self, x):
    h = x + self.attention(self.attention_norm(x))
    out = h + self.feed_forward(self.ffn_norm(h))
    return out

在大多数使用场景中,激活值(和梯度)在 AttentionFeedForward 模块之外具有 [批量大小, 序列长度, 隐藏维度] 的形状。在 DTensor 的语言中,序列并行(Sequence Parallel)使用 Shard(1) 布局进行模块的前向/后向激活计算。遵循前面的代码示例,下面的代码展示了如何将序列并行应用到 TransformerBlock 中的归一化层:

首先,让我们导入序列并行所需的依赖:

fromtorch.distributed.tensor.parallelimport (
    PrepareModuleInput,
    SequenceParallel,
)

接下来我们来调整 layer_tp_plan,以便在 RMSNorm 层上启用序列并行。

layer_tp_plan = {
    # Now the input and output of SequenceParallel has Shard(1) layouts,
    # to represent the input/output tensors sharded on the sequence dimension
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

可以看到,我们现在使用 PrepareModuleInput 将 Attention 和 FeedForward 层的模块输入布局从 Shard(1) 修改为 Replicate(),并将其输出布局标记为 Shard(1)。就像在 Tensor Parallelism 中一样,只需指定输入和输出的张量分片布局,层之间的通信将自动进行。

需要注意的是,在 Sequence Parallel 中,我们假设 TransformerBlock 的输入和输出始终在序列维度上进行分片,这样多个 TransformerBlock 可以无缝连接。这可以通过显式指定起始的 nn.Embedding 层的输出和最终的 nn.Linear 投影层的输入为 Shard(1) 来实现:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate()
        ),
    }
)

应用 Loss Parallel

Loss Parallel 是一种相关技术,用于在计算损失函数时节省内存和通信开销,因为模型输出通常非常大。在 Loss Parallel 中,当模型输出在(通常庞大的)词汇维度上分片时,交叉熵损失可以高效计算,而无需将所有模型输出收集到每个 GPU 上。这不仅显著减少了内存消耗,还通过减少通信开销和并行执行分片计算,提高了训练速度。下图简要说明了 Loss Parallel 如何通过分片计算避免将所有模型输出收集到每个 GPU 上。

loss parallel

图 2. 在单个 GPU 上使用 loss parallel 进行交叉熵损失前向计算。蓝色表示分片张量;绿色表示复制的张量;黄色表示具有部分值的张量(待进行 all-reduce 操作)。黑色箭头表示本地计算;红色箭头表示 GPU 之间的功能集合操作。

在 PyTorch Tensor Parallel API 中,可以通过上下文管理器 loss_parallel 启用 Loss Parallel,这样用户可以直接使用 torch.nn.functional.cross_entropytorch.nn.CrossEntropyLoss,而无需修改代码的其他部分。

要应用 Loss Parallel,模型预测值(通常形状为 [batch size, sequence length, vocabulary size])应在词汇维度上进行分片。这可以通过标记最后一个线性投影层输出的布局来轻松实现:

model = parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            # use DTensor as the output
            use_local_output=False,
        ),
    },
)

在上面的代码中,我们还在输出之前将 Sequence Parallel 应用于 norm 层。我们使用 use_local_output=False 让输出保持为 DTensor,以便与 loss_parallel 上下文管理器配合使用。之后,可以简单地调用交叉熵损失函数,如下所示。请注意,反向计算也需要在上下文中进行。

importtorch.nn.functionalasF
fromtorch.distributed.tensor.parallelimport loss_parallel

pred = model(input_ids)
with loss_parallel():
    # assuming pred and labels are of the shape [batch, seq, vocab]
    loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    loss.backward()

将张量并行与全分片数据并行结合使用

既然我们已经展示了如何将 Tensor/Sequence Parallel 应用于模型,接下来让我们也看看 Tensor Parallel 和 Fully Sharded Data Parallel 如何协同工作。由于 Tensor Parallelism 会引入阻塞计算的通信,我们希望确保它在快速的通信通道(如 NVLink)中运行。在实践中,我们通常在每个主机内部应用 Tensor Parallel,而在主机之间应用 Fully Sharded Data Parallel。

fsdp + tp

图 3. FSDP 和 TP 在不同的设备维度上工作,FSDP 通信发生在主机之间,而 TP 通信发生在主机内部。

这种二维并行模式可以通过二维 DeviceMesh 轻松表达,我们只需将每个“子”DeviceMesh 传递给各自的并行 API 即可:

fromtorch.distributed.device_meshimport init_device_mesh
fromtorch.distributed.tensor.parallelimport ColwiseParallel, RowwiseParallel, parallelize_module
fromtorch.distributed.fsdpimport FullyShardedDataParallel as FSDP

# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices

model = Model(...)

tp_plan = {...}

# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)

这将使我们能够轻松地在每个主机内(主机内)应用 Tensor Parallel,并在跨主机(主机间)应用 FSDP,且对 Llama 模型实现 零代码更改。Tensor(模型)并行和数据并行技术相结合,提供了继续增加模型规模并利用大量 GPU 进行高效训练的能力。

结论

本教程展示了如何结合使用 Tensor Parallel 和 Fully Sharded Data Parallel 在数百到数千个 GPU 上训练类似 Transformer 的大型模型。教程解释了如何将 Tensor Parallel 应用于模型的不同部分,而无需更改模型本身的代码。Tensor Parallel 是一种用于大规模训练的高效模型并行技术。

要查看本教程中解释的完整端到端代码示例,请参考 pytorch/examples 仓库中的 Tensor Parallel 示例

本页目录