使用 Tensor Parallel (TP) 进行大规模 Transformer 模型训练
作者: Wanchao Liang, Tianyu Liu
本教程演示了如何使用 Tensor Parallel 和 Fully Sharded Data Parallel 在数百到数千个 GPU 上训练一个类似 Transformer 的大型模型。
前提条件:
-
已安装 PyTorch 2.3.0 或更高版本,并支持 CUDA/Linux
Tensor Parallel 的工作原理
Tensor Parallel (TP) 最初在 Megatron-LM 论文中提出,它是一种高效的模型并行技术,用于训练大规模 Transformer 模型。我们在本教程中提到的 Sequence Parallel (SP) 是 Tensor Parallel 的一种变体,它在 nn.LayerNorm
或 RMSNorm
的序列维度上进行分片,以进一步节省训练期间的激活内存。随着模型规模的增大,激活内存成为瓶颈,因此在 Tensor Parallel 训练中,通常会将 Sequence Parallel 应用于 LayerNorm
或 RMSNorm
层。
图 1. 展示了在 Transformer 模型的 MLP 和 Self-Attention 层上的 Tensor Parallel 分片方式,其中注意力机制和 MLP 中的矩阵乘法通过分片计算进行(图片来源)
在高层面上,PyTorch 张量并行的运作方式如下:
分片初始化
-
确定要为每一层应用哪种
ParallelStyle
,并通过调用parallelize_module
来初始化模块并进行分片。 -
并行化的模块将把其模型参数替换为 DTensor,而 DTensor 将负责使用分片计算来运行并行化的模块。
运行时正向/反向传播
-
根据用户为每个
ParallelStyle
指定的输入/输出 DTensor 布局,它将运行适当的通信操作来转换输入/输出的 DTensor 布局(例如allreduce
、allgather
和reduce_scatter
)。 -
运行并行化层的分片计算以节省计算/内存(例如
nn.Linear
、nn.Embedding
)。
何时以及为何应该使用 Tensor 并行
PyTorch 的完全分片数据并行(FSDP)已经具备将模型训练扩展到特定数量 GPU 的能力。然而,当需要在模型规模和 GPU 数量方面进一步扩展模型训练时,会出现许多额外的挑战,可能需要将张量并行与 FSDP 结合使用。
-
随着世界大小(GPU数量)变得异常庞大(超过128/256个GPU),FSDP的集合操作(如
allgather
)主要由环形延迟主导。通过在FSDP之上实现TP/SP,可以将FSDP世界大小减少8倍,使FSDP仅在主机间应用,从而将延迟成本降低相同的幅度。 -
当由于收敛性和GPU内存限制无法将全局批大小提高到超过GPU数量时,数据并行性已达到极限,张量/序列并行是唯一已知的可以“大致”增加全局批大小并继续扩展更多GPU的方法。这意味着模型大小和GPU数量都可以继续扩展。
-
对于某些类型的模型,当本地批大小变小时,TP/SP可以产生更优化浮点运算(FLOPS)的矩阵乘法形状。
那么,在预训练时,达到这些限制的难度如何?目前,预训练一个拥有数十亿或数万亿标记的大语言模型(LLM)可能需要数月时间,即使使用数千个GPU也是如此。
-
在大规模训练LLM时,总会遇到第一个限制。例如,Llama 2 70B使用2000个GPU训练了35天,在2000个GPU的规模下需要多维并行技术。
-
当Transformer模型变得更大时(如Llama2 70B),也会迅速遇到第二个限制。由于内存和收敛性的限制,即使本地
batch_size=1
也无法单独使用FSDP。例如,Llama 2的全局批大小为1K,因此在2000个GPU的规模下无法单独使用数据并行。
如何应用 Tensor Parallel
PyTorch Tensor Parallel API 提供了一组模块级别的原语(ParallelStyle
),用于配置模型各个层的分片策略,包括:
-
ColwiseParallel
和RowwiseParallel
: 以列或行的方式对nn.Linear
和nn.Embedding
进行分片。 -
SequenceParallel
: 对nn.LayerNorm
、nn.Dropout
、RMSNormPython
等进行分片计算。 -
PrepareModuleInput
和PrepareModuleOutput
: 通过适当的通信操作配置模块输入/输出的分片布局。
为了演示如何使用 PyTorch 原生的 Tensor Parallel API,我们将以一个常见的 Transformer 模型为例。在本教程中,我们使用最新的 Llama2 模型 作为参考的 Transformer 模型实现,因为它在社区中也被广泛使用。
由于 Tensor Parallel 将单个张量分片到一组设备上,我们需要首先设置分布式环境(例如 NCCL 通信器)。Tensor Parallelism 是一种单程序多数据(SPMD)分片算法,类似于 PyTorch 的 DDP/FSDP,其底层利用了 PyTorch 的 DTensor 来执行分片。它还使用了 DeviceMesh 抽象(其底层管理着 ProcessGroups)来进行设备管理和分片。要了解如何利用 DeviceMesh 来设置多维并行,请参考 此教程。Tensor Parallel 通常在单个主机内工作,因此我们首先初始化一个连接主机内 8 个 GPU 的 DeviceMesh。
fromtorch.distributed.device_meshimport init_device_mesh
tp_mesh = init_device_mesh("cuda", (8,))
既然我们已经初始化了 DeviceMesh,接下来让我们详细了解一下 Llama 2 模型架构,并探讨如何进行 Tensor Parallel 的分片处理。我们将重点关注核心的 TransformerBlock
,其中 Transformer 模型通过堆叠相同的 TransformerBlock
来扩展模型规模。
核心的 TransformerBlock
包含一个 Attention
层和一个 FeedForward
层。首先让我们来看一下较为简单的 FeedForward
层。FeedForward
层由三个线性层组成,它采用了 SwiGLU 风格的多层感知机(MLP),以下是其前向函数的实现:
# forward in the FeedForward layer
defforward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
它同时执行 w1
和 w3
的矩阵乘法,然后使用 w1
/w3
线性投影结果的组合结果执行 w2
的矩阵乘法。这意味着我们可以借鉴 Tensor Parallelism 论文中的思想,将 w1
/w3
的线性层以列方式分片,并将 w2
的线性层以行方式分片,这样在三个层的最后只会发生一次 allreduce
通信。使用 PyTorch 原生的 Tensor Parallel,我们可以简单地为 FeedForward
层创建一个 parallelize_plan
,如下所示:
fromtorch.distributed.tensor.parallelimport ColwiseParallel, RowwiseParallel, parallelize_module
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"feed_foward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这就是我们如何使用 PyTorch Tensor Parallel API 来配置 FeedForward
层的分片。需要注意的是,用户只需指定如何对各个层进行分片,而通信(例如 allreduce
)将在底层自动处理。
接下来是 Attention
层。它由 wq
、wk
、wv
线性层组成,用于将输入投影到 q
/ k
/ v
,然后通过 wo
线性层执行注意力机制和输出投影。在这里,Tensor Parallelism 的目的是对 q/k/v 投影进行列分片,并对 wo
线性投影进行行分片。因此,我们可以将 Attention 计划添加到刚刚起草的 tp_plan
中:
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这几乎就是我们需要在 TransformerBlock
上应用 Tensor Parallelism 的 layer_tp_plan
。然而,有一点需要注意的是,当对线性层进行列向分片时,线性层的输出会在最后一个张量维度上被分片,而行向分片的线性层直接接受一个在最后一个维度上分片的输入。如果在列向线性层和行向线性层之间还有其他张量操作(如视图操作),我们需要将相关的形状操作调整为分片后的形状。
对于 Llama 模型,在注意力层中有一些与形状相关的视图操作。特别是对于 wq
/wk
/wv
线性层的列并行处理,激活张量在 num_heads
维度上进行了分片,因此我们需要将 num_heads
调整为局部的 num_heads
。
最后,我们需要调用 parallelize_module
API 来使每个 TransformerBlock
的计划生效。在底层,它将 Attention
和 FeedForward
层中的模型参数分配到 DTensors 中,并根据需要为模型输入和输出(分别在每个模块之前和之后)注册通信钩子:
for layer_id, transformer_block in enumerate(model.layers):
layer_tp_plan = {...} # i.e. the plan we just generated
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
既然我们已经详细说明了每个 TransformerBlock
的分片计划,通常在第一层会有一个 nn.Embedding
,而在最后一层会有一个最终的 nn.Linear
投影层,用户可以选择对第一个 nn.Embedding
进行行分片或列分片,并对最后一个 nn.Linear
投影层进行列分片,同时指定适当的输入和输出布局。以下是一个示例:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
output_layouts=Replicate(),
),
}
)
如果待分区的模型过大,无法放入CPU内存,可以尝试使用
meta
设备初始化(例如,先在meta设备上初始化模型,对层进行分片,然后再实例化模型),或者在Transformer模型初始化期间逐层并行化TransformerBlock
。
将序列并行应用于 LayerNorm/RMSNorm
层
序列并行(Sequence Parallel)建立在上述张量并行(Tensor Parallel)的基础上。与基本的张量并行相比,后者仅在Attention
模块和FeedForward
模块内对张量进行分片,并保持这些模块的输入和输出(即前向传播中的激活值和反向传播中的梯度)为复制状态,而序列并行则在这些模块的序列维度上对它们进行分片。
在一个典型的TransformerBlock
中,前向函数结合了归一化层(LayerNorm
或RMSNorm
)、注意力层、前馈层和残差连接。例如:
# forward in a TransformerBlock
defforward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out
在大多数使用场景中,激活值(和梯度)在 Attention
和 FeedForward
模块之外具有 [批量大小, 序列长度, 隐藏维度]
的形状。在 DTensor 的语言中,序列并行(Sequence Parallel)使用 Shard(1)
布局进行模块的前向/后向激活计算。遵循前面的代码示例,下面的代码展示了如何将序列并行应用到 TransformerBlock
中的归一化层:
首先,让我们导入序列并行所需的依赖:
fromtorch.distributed.tensor.parallelimport (
PrepareModuleInput,
SequenceParallel,
)
接下来我们来调整 layer_tp_plan
,以便在 RMSNorm
层上启用序列并行。
layer_tp_plan = {
# Now the input and output of SequenceParallel has Shard(1) layouts,
# to represent the input/output tensors sharded on the sequence dimension
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
}
可以看到,我们现在使用 PrepareModuleInput
将 Attention 和 FeedForward 层的模块输入布局从 Shard(1)
修改为 Replicate()
,并将其输出布局标记为 Shard(1)
。就像在 Tensor Parallelism 中一样,只需指定输入和输出的张量分片布局,层之间的通信将自动进行。
需要注意的是,在 Sequence Parallel 中,我们假设 TransformerBlock
的输入和输出始终在序列维度上进行分片,这样多个 TransformerBlock
可以无缝连接。这可以通过显式指定起始的 nn.Embedding
层的输出和最终的 nn.Linear
投影层的输入为 Shard(1)
来实现:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate()
),
}
)
应用 Loss Parallel
Loss Parallel 是一种相关技术,用于在计算损失函数时节省内存和通信开销,因为模型输出通常非常大。在 Loss Parallel 中,当模型输出在(通常庞大的)词汇维度上分片时,交叉熵损失可以高效计算,而无需将所有模型输出收集到每个 GPU 上。这不仅显著减少了内存消耗,还通过减少通信开销和并行执行分片计算,提高了训练速度。下图简要说明了 Loss Parallel 如何通过分片计算避免将所有模型输出收集到每个 GPU 上。
图 2. 在单个 GPU 上使用 loss parallel 进行交叉熵损失前向计算。蓝色表示分片张量;绿色表示复制的张量;黄色表示具有部分值的张量(待进行 all-reduce 操作)。黑色箭头表示本地计算;红色箭头表示 GPU 之间的功能集合操作。
在 PyTorch Tensor Parallel API 中,可以通过上下文管理器 loss_parallel
启用 Loss Parallel,这样用户可以直接使用 torch.nn.functional.cross_entropy
或 torch.nn.CrossEntropyLoss
,而无需修改代码的其他部分。
要应用 Loss Parallel,模型预测值(通常形状为 [batch size, sequence length, vocabulary size]
)应在词汇维度上进行分片。这可以通过标记最后一个线性投影层输出的布局来轻松实现:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
# use DTensor as the output
use_local_output=False,
),
},
)
在上面的代码中,我们还在输出之前将 Sequence Parallel 应用于 norm 层。我们使用 use_local_output=False
让输出保持为 DTensor,以便与 loss_parallel
上下文管理器配合使用。之后,可以简单地调用交叉熵损失函数,如下所示。请注意,反向计算也需要在上下文中进行。
importtorch.nn.functionalasF
fromtorch.distributed.tensor.parallelimport loss_parallel
pred = model(input_ids)
with loss_parallel():
# assuming pred and labels are of the shape [batch, seq, vocab]
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
loss.backward()
将张量并行与全分片数据并行结合使用
既然我们已经展示了如何将 Tensor/Sequence Parallel 应用于模型,接下来让我们也看看 Tensor Parallel 和 Fully Sharded Data Parallel 如何协同工作。由于 Tensor Parallelism 会引入阻塞计算的通信,我们希望确保它在快速的通信通道(如 NVLink)中运行。在实践中,我们通常在每个主机内部应用 Tensor Parallel,而在主机之间应用 Fully Sharded Data Parallel。
这种二维并行模式可以通过二维 DeviceMesh 轻松表达,我们只需将每个“子”DeviceMesh 传递给各自的并行 API 即可:
fromtorch.distributed.device_meshimport init_device_mesh
fromtorch.distributed.tensor.parallelimport ColwiseParallel, RowwiseParallel, parallelize_module
fromtorch.distributed.fsdpimport FullyShardedDataParallel as FSDP
# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices
model = Model(...)
tp_plan = {...}
# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
这将使我们能够轻松地在每个主机内(主机内)应用 Tensor Parallel,并在跨主机(主机间)应用 FSDP,且对 Llama 模型实现 零代码更改。Tensor(模型)并行和数据并行技术相结合,提供了继续增加模型规模并利用大量 GPU 进行高效训练的能力。
结论
本教程展示了如何结合使用 Tensor Parallel 和 Fully Sharded Data Parallel 在数百到数千个 GPU 上训练类似 Transformer 的大型模型。教程解释了如何将 Tensor Parallel 应用于模型的不同部分,而无需更改模型本身的代码。Tensor Parallel 是一种用于大规模训练的高效模型并行技术。
要查看本教程中解释的完整端到端代码示例,请参考 pytorch/examples 仓库中的 Tensor Parallel 示例。