张量并行 - torch.distributed.tensor.parallel
张量并行性(TP)基于PyTorch分布式张量(DTensor)构建,并提供了多种并行风格:列式、行式和序列并行。
警告
张量并行API处于试验阶段,可能随时发生变化。
使用张量并行化来并行化你的nn.Module
的入口点是:
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan)[源代码]
-
在 PyTorch 中应用张量并行性的方法是根据用户指定的计划并行化模块或子模块。
我们依据 parallelize_plan 对模块或子模块进行并行处理。parallelize_plan 中包含
ParallelStyle
,它指明了用户期望的并行方式。用户还可以为每个模块的完全限定名称(FQN)指定不同的并行样式。
注意,
parallelize_module
只接受一维的DeviceMesh
。如果你有一个二维或 N 维的DeviceMesh
,需要先将其切片为一维子 DeviceMesh,然后再传递给此 API(例如device_mesh["tp"]
)。- 参数
-
-
module (
nn.Module
) – 需要进行并行处理的模块。 -
device_mesh (
DeviceMesh
) – 用于描述 DTensor 设备拓扑结构的对象。 -
parallelize_plan (Union[
ParallelStyle
, Dict[str,ParallelStyle
]]) – 用于模块并行化的计划。它可以是一个包含张量并行输入/输出准备方式的ParallelStyle
对象,也可以是模块 FQN 和其对应的ParallelStyle
对象组成的字典。
-
- 返回值
-
将一个
nn.Module
对象进行并行化处理。 - 返回类型
- 示例:
-
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> # Define the module. >>> m = Model(...) >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) >>>
注意
对于复杂的模块架构,如Attention和MLP层,我们建议将不同的ParallelStyles(例如
ColwiseParallel
和RowwiseParallel
)组合起来,并将其作为parallelize_plan传递,以实现所需的分片计算。
张量并行支持以下几种并行风格:
- class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True) [源代码]
-
按列方式拆分兼容的nn.Module。目前支持nn.Linear和nn.Embedding。用户可以将它与RowwiseParallel结合使用,以实现更复杂模块(如MLP、Attention)的分片。
- 关键字参数
- 返回值
-
一个表示 nn.Module 列分片的
ParallelStyle
对象。
- 示例:
-
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ...
注意
默认情况下,如果没有指定
output_layouts
,ColwiseParallel
的输出将在最后一个维度上进行分片。如果有需要特定张量形状的操作(即在配对的RowwiseParallel
之前),请记住,如果输出进行了分片,则操作可能需要调整以适应分片后的大小。
- class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True) [源代码]
-
按照行的方式将兼容的nn.Module进行分区,目前支持nn.Linear和nn.Embedding。用户可以将其与ColwiseParallel结合使用,以实现更复杂模块(如MLP、Attention)的分片。
- 关键字参数
- 返回值
-
一个表示 nn.Module 行分片的
ParallelStyle
对象。
- 示例:
-
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ...
- 类torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[源代码]
-
SequenceParallel 复制一个兼容的
nn.Module
参数,并在序列维度上分片输入的情况下运行分片计算。目前支持的功能包括nn.LayerNorm
、nn.Dropout
以及 RMSNorm 的 Python 实现。这种风格实现了论文《减少大型变压器模型中的激活重新计算》(Reducing Activation Recomputation in Large Transformer Models)中描述的操作。
如果传递给此
nn.Module
的输入是一个torch.Tensor
,它假设该输入已经在序列维度上进行了分片,并将其转换为在序列维度上进行分片的DTensor
。如果传递给此nn.Module
的输入已经是DTensor
,但未在序列维度上进行分片,则会重新分布该输入以使其在序列维度上进行分片。nn.Module
的输出将在序列维度上进行分割。- 关键字参数
- 返回值
-
一个表示
nn.Module
序列并行的ParallelStyle
对象。
- 示例:
-
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ...
注意
SequenceParallel 风格假设 nn.Module 中如果有权重(例如
nn.LayerNorm
或RMSNorm
),这些权重默认会被初始化为1。如果你对这些模块的权重进行了自定义初始化,需要在并行化之前或之后广播这些权重以确保它们被正确复制。
为了简单地配置 nn.Module 的输入和输出的 DTensor 布局,并执行必要的布局重分布,而不将模块参数分配到 DTensors,可以在调用 parallelize_module
时,在 parallelize_plan
中使用以下 ParallelStyle
:
- class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[源代码]
-
配置 nn.Module 的输入,使其在运行时根据
input_layouts
将输入张量转换为 DTensor,并根据desired_input_layouts
进行布局重分布。- 关键字参数
-
-
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 的输入张量的 DTensor 布局,用于将输入张量转换为 DTensors。如果某些输入不是 torch.Tensor 或者不需要转换为 DTensors,则需要指定
None
作为占位符。默认值: None。 -
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – 输入张量的 nn.Module 所需的 DTensor 布局,用于确保 nn.Module 的输入具有所需的 DTensor 布局。此参数需要与
input_layouts
参数长度相同。默认值为 None。 -
input_kwarg_layouts (Dict[str, Placement]) – nn.Module 的输入 kwargs 的 DTensor 布局,用于将输入的 kwarg 张量转换为 DTensors。默认值:None
-
desired_input_kwarg_layouts – (Dict[str, Placement]): nn.Module 输入 kwargs 的预期 DTensor 布局,用于确保 nn.Module 的输入具有所需的 DTensor 布局。默认值为 None。
-
use_local_output (bool, optional) – 是否使用本地
torch.Tensor
而不是DTensor
作为模块的输入,默认值为 False。
-
- 返回值
-
一个
ParallelStyle
对象,用于准备 nn.Module 输入的分片布局。
- 示例:
-
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor >>> # and then redistributed to Replicated DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...) >>> ), >>> } >>> )
- 类torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[源代码]
-
根据
output_layouts
配置nn.Module
的输出,将输出张量在运行时转换为DTensor
,并根据desired_output_layouts
执行布局重分布。- 关键字参数
-
-
output_layouts (Union[Placement, Tuple[Placement]]) – 用于 nn.Module 输出张量的 DTensor 布局。如果输出张量是
torch.Tensor
,则将其转换为 DTensors。如果某些输出不是 torch.Tensor 或不需要转换为 DTensors,则需要指定None
作为占位符。 -
desired_output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 输出张量的期望 DTensor 布局,用于确保 nn.Module 的输出具有所需的 DTensor 布局。
-
use_local_output (bool, optional) – 是否将模块输出使用本地的
torch.Tensor
而不是DTensor
,默认为 True。
-
- 返回值
-
一个ParallelStyle对象,用于准备nn.Module输出的分区布局。
- 示例:
-
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor >>> # and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan = PrepareModuleOutput( >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0) >>> ) >>> )
注意
当使用Shard(dim)
作为上述ParallelStyle
的输入/输出布局时,我们假设在TP运行的DeviceMesh
上,输入和输出激活张量在维度dim
上被均匀分片。例如,由于RowwiseParallel
接受最后一个维度分片的输入,它假定输入张量已经在该维度上被均匀地分片了。对于不均匀分片的情况,可以将DTensor直接传递给分区模块,并使用use_local_output=False
来返回每个ParallelStyle
之后的DTensor,这样DTensor就可以跟踪不均匀分片的信息。
对于像Transformer这样的模型,我们建议用户在parallelize_plan中同时使用ColwiseParallel
和RowwiseParallel
来实现整个模型(包括Attention和MLP)所需的分片。
可以通过以下上下文管理器来支持并行化的交叉熵损失计算(损失并行化):
- torch.distributed.tensor.parallel.loss_parallel()[源代码]
-
一个上下文管理器,用于启用损失并行性,在输入数据按类维度分片的情况下,能够高效地进行并行损失计算。目前只支持交叉熵损失。
在此上下文管理器中,可以像平常一样使用
cross_entropy()
或CrossEntropyLoss
,并假设输入参数满足以下条件。如果有的话,相应的backward()
调用也需要在此上下文管理器中进行。- 参数
-
-
输入 (
DTensor
) – 输入的 logits。假设在类别维度上进行了分片处理。 -
target (Union[
torch.Tensor
,DTensor
]) – 必须是 ground truth 类别索引(当前不支持类别概率)。假设在DeviceMesh
上进行了复制。 -
weight (Union[
torch.Tensor
,DTensor
],可选) – 如果提供,则假定在DeviceMesh
上进行了复制。 -
label_smoothing – 当前未提供支持。
-
- 返回值
-
一个复制的
DTensor
。
示例
这里手动创建了一个分片的DTensor来展示其用法。实际上,它通常是由TP模块产生的输出。
>>> from torch.distributed.tensor.parallel import loss_parallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> device_mesh = init_device_mesh("cuda", (8,)) >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) >>> target = torch.randint(16, (4,), device="cuda") >>> with loss_parallel(): >>> loss = F.cross_entropy(dist_input, target, reduction="mean") >>> loss.backward() >>> ...
警告
loss_parallel API 是试验性的,可能随时会改变。