torch.distributed.tensor
注意
torch.distributed.tensor 当前处于 alpha 状态,并且仍在开发中。我们承诺保持文档中列出的大部分 API 的向后兼容性,但在必要时也可能进行 API 更改。
PyTorch 分布式张量(DTensor)
PyTorch DTensor 提供简单且灵活的张量分片原语,透明地处理分布式逻辑,包括分片存储、操作计算以及设备/主机间的集体通信。在与多维分片一起工作时,DTensor 可用于构建不同的并行解决方案,并支持分片状态字典表示。
请参阅 PyTorch 基于 DTensor 的原生并行解决方案示例:
DTensor 采用 SPMD(单程序多数据)编程模型,使用户能够编写分布式程序,就像它是在单一设备上运行的程序一样,并且具有相同的收敛属性。它通过指定 DeviceMesh 和 Placement 提供统一的张量分片布局(DTensor 布局):
-
DeviceMesh使用一个 n 维数组来表示集群的设备拓扑和通信器。 -
Placement描述了逻辑张量在DeviceMesh上的分片布局。DTensor 支持三种放置类型:Shard、Replicate和Partial。
DTensor 类 APIs
DTensor 是 torch.Tensor 的子类。这意味着一旦创建了 DTensor,就可以像使用 torch.Tensor 一样使用它,包括运行不同类型的 PyTorch 操作符,就像在单个设备上运行一样,从而实现 PyTorch 操作符的分布式计算。
除了现有的torch.Tensor方法之外,它还提供了一组额外的方法来与torch.Tensor进行交互,重新分配DTensor布局到一个新的DTensor,并获取所有设备上的完整张量内容等。
- 类torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)
-
DTensor(分布式张量)是torch.Tensor的一个子类,它为多设备上的torch.Tensor提供了单设备类似的编程抽象。通过DeviceMesh和以下类型的Placement描述分布式张量的分片布局(DTensor 布局):-
Shard: 在DeviceMesh维度的设备上,张量在dim维度上进行分片 -
Replicate: 在DeviceMesh维度的设备上复制张量 -
Partial: 张量在DeviceMesh维度的设备上等待减少
在调用PyTorch操作符时,
DTensor会重写这些操作符以执行分片计算,并在必要时发起通信。除了进行操作符计算之外,DTensor还会根据操作符本身的语义正确地转换或传播布局(即DTensor 布局),并生成新的DTensor输出。为了确保在调用PyTorch操作符时
DTensor分片计算的数值正确性,需要将操作符的每个张量参数设置为DTensor。- 返回类型
- 属性device_mesh:DeviceMesh
-
与这个 DTensor 对象相关的
DeviceMesh属性。注意
device_mesh是一个只读属性,无法进行设置。
- 静态 from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None) [源代码](/_modules/torch/distributed/tensor/_api.html#DTensor.from_local)
-
根据指定的
device_mesh和placements,在每个排名上从本地 torch.Tensor 创建一个DTensor。- 参数
-
-
local_tensor (torch.Tensor) – 每个 rank 上的本地 torch.Tensor。
-
device_mesh (
DeviceMesh, 可选) - 用于放置张量的设备网格。如果未指定,则需要在 DeviceMesh 上下文中进行调用, 默认值为 None。 -
placements (List[
Placement], optional) – 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的方式,必须与device_mesh.ndim的元素数量相等。
-
- 关键字参数
-
-
run_check (bool, 可选) – 以额外通信为代价,在各个 ranks 上执行完整性检查,验证每个本地张量的元信息以确保正确性。如果在
placements中存在Replicate,设备网格维度上的第一个 rank 的数据将被广播到其他 ranks。默认值:False -
shape (torch.Size, 可选) – 一个整数列表,指定在local_tensor基础上构建的DTensor的大小。如果不同rank上的
local_tensor形状不一致,则必须提供此参数。如果没有提供,默认假设给定的分布式张量在各个rank上均匀分片来计算shape。默认值:None -
stride (元组, 可选) – 一个整数列表,用于指定 DTensor 的步长。如果没有提供,默认情况下假定给定的分布式张量在各个 ranks 上均匀分片。
stride默认值为 None。
-
- 返回值
-
一个
DTensor对象 - 返回类型
注意
当
run_check=False时,确保传入的本地张量在各个 ranks 上是正确的(即对于Shard(dim)布局进行了分片或对于Replicate()布局进行了复制)是用户的责任。否则,创建的 DTensor 的行为将是未定义的。注意
from_local是可微的,创建的 DTensor 对象的 requires_grad 属性将根据 local_tensor 是否需要梯度来决定。
- full_tensor(*, grad_placements=None)[源代码]
-
返回此 DTensor 的完整张量。它会执行必要的通信操作,从设备网格中的其他设备收集本地张量,并将它们拼接在一起。这相当于以下代码的简化形式:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()- 关键字参数
-
grad_placements (List[
Placement],可选) – 描述从该函数返回的完整张量的任何未来梯度布局。full_tensor 将 DTensor 转换为完整的 torch.Tensor,并且返回的 torch.tensor 可能在后续代码中不会被用作原始复制的 DTensor 布局。此参数是用户可以提供给 autograd 的提示,告知如果返回张量的梯度布局与原始复制的 DTensor 布局不匹配的情况。如果没有指定,默认假设完整张量的梯度布局为复制。 - 返回值
-
一个表示此 DTensor 完整张量的
torch.Tensor对象。 - 返回类型
注意
full_tensor是可微的。
- propertyplacements:Tuple[Placement,...]
-
此DTensor的placements属性描述了它在DeviceMesh上的布局。
注意
placements是一个只读属性,无法进行设置。
- redistribute(device_mesh=None, placements=None, *, async_op=False)[源代码]
-
redistribute执行必要的集体操作,将当前的 DTensor 从其当前位置重分布到新的位置,或者从当前的 DeviceMesh 重分布到一个新的 DeviceMesh。例如,我们可以通过为每个维度指定 Replicate 位置来将一个 Sharded DTensor 转换为 Replicated DTensor。当在设备网格的一个维度上从当前位置重新分布到新位置时,我们将执行以下操作,包括通信集合和本地操作:
-
Shard(dim)->Replicate():all_gather -
Shard(src_dim)->Shard(dst_dim):all_to_all -
Replicate()->Shard(dim): 当前切分(即torch.chunk) -
Partial()->Replicate():all_reduce -
Partial()->Shard(dim):reduce_scatter
redistribute能正确计算无论是在一维还是多维设备网格上创建的DTensor所需的所有重新分布步骤。- 参数
-
-
device_mesh (
DeviceMesh, 可选) - 用于放置 DTensor 的设备网格。如果未指定,默认使用当前 DTensor 的设备网格。默认值:None -
placements (List[
Placement], optional) – 描述如何将 DTensor 放入 DeviceMesh 中的新放置方式,必须与device_mesh.ndim的元素数量相同。默认情况下,在所有网格维度上进行复制。
-
- 关键字参数
-
async_op (bool, optional) – 是否异步执行 DTensor 重分布操作。默认为 False。
- 返回值
-
一个
DTensor对象 - 返回类型
注意
redistribute是可微的,因此用户无需担心重分配操作的反向计算公式。注意
redistribute当前仅支持在同一 DeviceMesh 内重新分布 DTensor。如果需要将 DTensor 重新分布到不同的 DeviceMesh,请提交一个问题。 -
- to_local(*, grad_placements=None)[源代码]
-
获取此DTensor在其当前rank上的本地张量。对于分片,它返回逻辑张量视图的本地分片;对于复制,它返回当前rank上的副本。
- 关键字参数
-
grad_placements (List[
Placement],可选) – 描述从该函数返回的张量的未来梯度布局。to_local 将 DTensor 转换为本地张量,并且返回的本地张量可能不会在后续代码中继续使用原始 DTensor 布局。此参数是用户可以提供给自动微分器的提示,告知其如果返回张量的梯度布局与原始 DTensor 不匹配的情况。如果没有指定,默认假设梯度布局保持不变,并以此进行计算。 - 返回值
-
A
torch.Tensor或AsyncCollectiveTensor对象。它表示当前 rank 上的本地张量。当返回一个AsyncCollectiveTensor对象时,说明本地张量还未准备好(即通信尚未完成)。此时,用户需要调用wait方法来等待本地张量准备就绪。 - 返回类型
注意
to_local是可微的,返回的本地张量的requires_grad属性取决于 DTensor 是否设置了requires_grad。
-
DeviceMesh 作为分布式通讯工具
DeviceMesh 是基于 DTensor 构建的,用于描述集群设备拓扑并表示多维通信器(在 ProcessGroup 之上)。有关如何创建和使用 DeviceMesh 的详细信息,请参阅DeviceMesh 实用技巧。
DTensor放置类型
DTensor 支持以下类型的 Placement,在每个 DeviceMesh 维度上:
- 类torch.distributed.tensor.placement_types.Shard(dim)[源代码]
-
The
Shard(dim)placement describes how the DTensor is sharded along tensor dimensiondimover a correspondingDeviceMeshdimension, where each rank on the DeviceMesh dimension holds only one shard of the global Tensor. TheShard(dim)placement follows thetorch.chunk(dim)semantics: when the tensor dimension is not evenly divisible by the DeviceMesh dimension, some ranks at the end might be empty. TheShardplacement can be used with all DTensor APIs (e.g., distribute_tensor, from_local).- 参数
-
dim (int) – 表示 DTensor 在其相应 DeviceMesh 维度上进行分区的张量维度。
警告
当张量维度的大小不能被设备网格维度整除时,在设备网格维度上对该张量维度进行分片操作目前还处于实验阶段,未来可能有所变化。
- 类torch.distributed.tensor.placement_types.Replicate[源代码]
-
The
Replicate()placement describes how a DTensor is replicated across a correspondingDeviceMeshdimension, where each rank in the DeviceMesh holds a replica of the global Tensor. TheReplicateplacement can be used by all DTensor APIs (e.g.,distribute_tensor,DTensor.from_local, etc.).
- 类torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[源代码]
-
The
Partial(reduce_op)placement describes the DTensor that is pending reduction on a specifiedDeviceMeshdimension, where each rank on the DeviceMesh dimension holds the partial value of the global Tensor. 用户可以使用redistribute操作将PartialDTensor重新分布到指定的DeviceMesh维度上的Replicate或Shard(dim)布局,这会触发底层必要的通信操作(如allreduce,reduce_scatter)。- 参数
-
reduce_op (str, 可选) – 用于部分 DTensor 的缩减操作,以生成 Replicated 或 Sharded DTensor。仅支持元素级的缩减操作,包括: “sum”,“avg”,“product”,“max”,“min”。默认值为 “sum”。
注意
Partial布局可以通过 DTensor 操作生成,并且只能通过DTensor.from_localAPI 来使用。- reduce_op:str='sum'
- 类torch.distributed.tensor.placement_types.Placement[源代码]
-
这是 Placement 类型的基础类,用于描述 DTensor 如何放置在
DeviceMesh上。结合Placement和DeviceMesh可以定义 DTensor 的布局。它是三种主要的 DTensor Placement 类型:Shard、Replicate和Partial的基础。此类不直接使用,主要用于作为类型的占位符。
不同方式创建DTensor
-
构造一个
DTensor有三种方法: -
-
distribute_tensor()从每个进程上的逻辑或“全局”torch.Tensor创建一个DTensor。这可以用于分片模型的叶子torch.Tensor(例如参数和输入)。 -
DTensor.from_local()从每个 rank 上的本地torch.Tensor创建一个DTensor,这可以用于从非叶子torch.Tensor(例如前向/后向过程中的中间激活张量)创建DTensor。 -
DTensor 提供了专门的张量工厂函数(例如
empty()、ones()和randn()等),通过直接指定DeviceMesh和Placement来创建不同的DTensor。与distribute_tensor()不同,这可以直接在设备上初始化分片内存,而无需先初始化逻辑张量内存再进行分片。
-
将逻辑 torch.Tensor 转换为 DTensor
在 torch.distributed 中,SPMD(单程序多数据)编程模型通过启动多个进程(例如使用 torchrun)来执行同一个程序。这意味着程序中的模型会在不同的进程中进行初始化(即模型可能在 CPU、元设备或直接在 GPU 上进行初始化,前提是内存足够)。
DTensor 提供了一个 distribute_tensor() API,可以将模型权重或张量分片为 DTensors,并在每个进程中从“逻辑”张量创建一个 DTensor。这使得生成的 DTensors 符合单设备语义,这对于数值正确性至关重要。
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None)
-
将一个叶节点
torch.Tensor(例如 nn.Parameter 或 buffers)根据指定的placements分发到device_mesh中。device_mesh和placements的秩必须相同。tensor是要分发的逻辑或“全局”张量,API 会使用 DeviceMesh 维度中第一个 rank 的tensor作为事实来源以保持单设备语义。如果你希望在 Autograd 计算过程中构造一个 DTensor,请改用DTensor.from_local()。- 参数
-
-
tensor (torch.Tensor) – 要进行分布的 torch.Tensor。如果要在某个维度上对张量进行分区,而该维度不能被该网状结构中的设备数量整除,则使用
torch.chunk语义来分割张量并分散碎片。这种不均匀的分区行为是实验性的,并且可能会发生变化。 -
device_mesh (
DeviceMesh, 可选) - 用于分配张量的设备网格。如果未指定,则需要在 DeviceMesh 上下文管理器中进行调用,默认值为 None。 -
placements (List[
Placement], optional) – 描述如何在DeviceMesh上放置张量的位置信息,必须与device_mesh.ndim的元素数量相同。如果没有指定,默认会从device_mesh每个维度的第一个rank开始将张量在整个device_mesh上进行复制。
-
- 返回值
-
一个
DTensor或XLAShardedTensor对象。 - 返回类型
注意
当使用
xla设备类型初始化DeviceMesh时,distribute_tensor将返回XLAShardedTensor。更多详情请参见此问题。XLA的集成是实验性的,可能会发生变化。
除了distribute_tensor()之外,DTensor 还提供了distribute_module() API,以便在nn.Module级别更轻松地进行分片。
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)
-
该函数提供了三种功能来控制模块的参数、输入和输出。
1. 在运行时执行之前通过指定
partition_fn(即允许用户根据指定的partition_fn将模块参数转换为DTensor参数)对模块进行分片。2. 通过在运行时执行期间指定input_fn和output_fn来控制模块的输入或输出。(即,将输入转换为DTensor,将输出重新转换回torch.Tensor)- 参数
-
-
module (
nn.Module) – 待划分的用户模块。 -
device_mesh (
DeviceMesh) – 设备网格,用于放置模块。 -
partition_fn (Callable) – 用于划分参数的函数(即在
device_mesh上分片某些参数)。如果未指定partition_fn,则默认情况下会将module的所有模块参数在整个网格中进行复制。 -
input_fn (Callable) – 指定输入分布,可以控制模块输入的分片方式。该函数将被安装为模块的
forward_pre_hook(前向钩子)。 -
output_fn (Callable) – 指定输出分布,可以控制输出的分片方式或将其转换回 torch.Tensor。
output_fn将被安装为模块forward_hook(前向传播后的钩子)。
-
- 返回值
-
一个包含所有参数和缓冲区都是
DTensor的模块。 - 返回类型
注意
当使用
xla设备类型初始化DeviceMesh时,distribute_module会返回一个带有PyTorch/XLA SPMD注解参数的nn.Module。更多详情请参见此问题。XLA集成是实验性的,可能会发生变化。
DTensor 工厂函数
DTensor 还提供了专用的张量工厂函数,允许直接使用类似 torch.Tensor 的工厂函数 API(如 torch.ones, torch.empty 等)来创建 DTensor,并通过指定 DeviceMesh 和 Placement 来定义所创建的 DTensor:
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
-
返回一个填充了标量值 0 的
DTensor。- 参数
-
size (int...) – 定义输出
DTensor形状的一系列整数。这些整数可以作为可变数量的参数传递,也可以作为一个列表或元组等集合形式提供。例如:zeros(1,2,3..)、zeros([1,2,3..]) 或 zeros((1,2,3..)) - 关键参数
-
-
requires_grad (bool, optional) – 是否应在返回的
DTensor上记录自动求梯度操作。默认值:False。 -
dtype (
torch.dtype, 可选) – 返回的DTensor的期望数据类型。默认情况下,如果为None,则使用全局默认值(参见torch.set_default_dtype())。 -
layout (
torch.layout, 可选) – 返回的DTensor的期望布局。默认值:torch.strided。 -
device_mesh –
DeviceMesh类型,包含排名的网格信息 -
placements — 一个包含
Placement类型的序列:如Shard和Replicate
-
- 返回值
-
每个进程中的
DTensor对象 - 返回类型
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
-
返回一个
DTensor,该张量用标量值 1 填充,并且其形状由变量参数size定义。- 参数
-
size (int...) – 定义输出
DTensor形状的一系列整数。这些整数可以作为可变数量的参数传递,也可以作为一个列表或元组等集合形式提供。例如:ones(1, 2, 3..)、ones([1, 2, 3..]) 或 ones((1, 2, 3..)) - 关键参数
-
-
dtype (
torch.dtype, 可选) – 返回的DTensor的期望数据类型。默认情况下,如果为None,则使用全局默认值(参见torch.set_default_dtype())。 -
layout (
torch.layout, 可选) — 指定返回的 DTensor 的布局。默认值为torch.strided。 -
requires_grad (bool, optional) – 是否应在返回的
DTensor上记录自动求梯度操作。默认值:False。 -
device_mesh –
DeviceMesh类型,包含排名的网格信息 -
placements — 一个包含
Placement类型的序列:如Shard和Replicate
-
- 返回值
-
每个进程中的
DTensor对象 - 返回类型
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
-
返回一个
DTensor,该张量包含未初始化的数据。DTensor的形状由变量参数size定义。- 参数
-
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数,也可以是一个列表或元组等集合形式。例如:empty(1,2,3..)、empty([1,2,3..]) 或 empty((1,2,3..)) - 关键参数
-
-
dtype (
torch.dtype, 可选) – 返回的DTensor的期望数据类型。默认:如果为None,使用全局默认值(参见torch.set_default_dtype())。layout (torch.layout, 可选) – 返回的DTensor的期望布局。默认:torch.strided。 -
requires_grad (bool, optional) – 是否应在返回的
DTensor上记录自动求梯度操作。默认值:False。 -
device_mesh –
DeviceMesh类型,包含排名的网格信息 -
placements — 一个包含
Placement类型的序列:如Shard和Replicate
-
- 返回值
-
每个进程中的
DTensor对象 - 返回类型
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
-
根据
device_mesh和placements返回一个填充了fill_value的DTensor,其形状由参数size定义。- 参数
- 关键参数
-
-
dtype (
torch.dtype, 可选) – 返回的DTensor的期望数据类型。默认情况下,如果为None,则使用全局默认值(参见torch.set_default_dtype())。 -
layout (
torch.layout, 可选) — 指定返回的 DTensor 的布局。默认值为torch.strided。 -
requires_grad (bool, optional) – 是否应在返回的
DTensor上记录自动求梯度操作。默认值:False。 -
device_mesh –
DeviceMesh类型,包含排名的网格信息。 -
placements — 一个包含
Placement类型的序列:如Shard和Replicate
-
- 返回值
-
每个进程中的
DTensor对象 - 返回类型
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
-
返回一个
DTensor,该张量包含从区间[0, 1)上的均匀分布中随机生成的数字。张量的形状由变量参数size定义。- 参数
-
size (int...) – 定义输出
DTensor形状的一系列整数。这些整数可以作为可变数量的参数传递,也可以作为一个列表或元组等集合形式提供。例如:ones(1, 2, 3..)、ones([1, 2, 3..]) 或 ones((1, 2, 3..)) - 关键参数
-
-
dtype (
torch.dtype, 可选) – 返回的DTensor的期望数据类型。默认情况下,如果为None,则使用全局默认值(参见torch.set_default_dtype())。 -
layout (
torch.layout, 可选) — 指定返回的 DTensor 的布局。默认值为torch.strided。 -
requires_grad (bool, optional) – 是否应在返回的
DTensor上记录自动求梯度操作。默认值:False。 -
device_mesh –
DeviceMesh类型,包含排名的网格信息。 -
placements — 一个包含
Placement类型的序列:如Shard和Replicate
-
- 返回值
-
每个进程中的
DTensor对象 - 返回类型
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
-
返回一个
DTensor,该张量包含均值为0、方差为1的正态分布随机数。张量的形状由参数size定义。- 参数
-
size (int...) – 定义输出
DTensor形状的一系列整数。这些整数可以作为可变数量的参数传递,也可以作为一个列表或元组等集合形式提供。例如:ones(1, 2, 3..)、ones([1, 2, 3..]) 或 ones((1, 2, 3..)) - 关键参数
-
-
dtype (
torch.dtype, 可选) – 返回的DTensor的期望数据类型。默认情况下,如果为None,则使用全局默认值(参见torch.set_default_dtype())。 -
layout (
torch.layout, 可选) — 指定返回的 DTensor 的布局。默认值为torch.strided。 -
requires_grad (bool, optional) – 是否应在返回的
DTensor上记录自动求梯度操作。默认值:False。 -
device_mesh –
DeviceMesh类型,包含排名的网格信息。 -
placements — 一个包含
Placement类型的序列:如Shard和Replicate
-
- 返回值
-
每个进程中的
DTensor对象 - 返回类型
调试
日志记录
在启动程序时,你可以通过设置 TORCH_LOGS 环境变量(来自torch._logging)来开启额外的日志记录:
-
TORCH_LOGS=+dtensor 会显示 logging.DEBUG 及以上级别所有的日志消息。
-
TORCH_LOGS=dtensor 会显示 logging.INFO 及以上级别的日志消息。
-
TORCH_LOGS=-dtensor 将显示 logging.WARNING 及更高级别的日志消息。
调试工具
为了调试使用了DTensor的程序,并深入了解底层发生的集合操作细节,DTensor提供了一个CommDebugMode:
- 类torch.distributed.tensor.debug.CommDebugMode
-
CommDebugMode是一个上下文管理器,用于计算其上下文中功能集合操作的次数。它通过使用TorchDispatchMode来实现这一点。示例使用方法
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[源代码]
-
生成详细的表格,显示模块级别的操作和集体追踪信息。信息的多少取决于噪声级别。
-
显示模块级别的总计数
-
打印除简单操作外的dTensor操作以及模块信息
-
打印不包含在简单操作中的操作
-
显示所有操作
-
- generate_json_dump(file_name='comm_mode_log.json', noise_level=3)[源代码]
-
创建用于构建浏览器可视化效果的JSON文件:0. 打印模块级别的汇总计数;1. 打印不在平凡操作中的dTensor操作;2. 打印不在平凡操作中的其他操作;3. 打印所有操作。
- get_comm_counts()[源代码]
-
返回通信计数,形式为字典。
- 返回值
-
通信内容被视为一个字典。
- 返回类型
-
Dict[Any, int]
- log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)[源代码]
-
控制台 CommDebugMode 输出的替代方案,将内容写入用户指定的文件
为了可视化一个具有少于三个维度的DTensor的分区情况,DTensor提供了visualize_sharding():
- torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')
-
在终端中可视化
DTensor的一维或二维分片情况。注意
这需要
tabulate包。对于空张量,不会打印分片信息。
实验性特征
DTensor 还提供了一些实验性功能。这些功能要么还在原型阶段,要么基本功能已完成并等待用户反馈。如果你有任何意见或建议,请通过 PyTorch 提交一个 issue。
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)
-
local_map()是一个实验性 API,允许用户将DTensor传递给为torch.Tensor编写的函数。它通过提取DTensor的本地组件、调用该函数,并根据out_placements将输出重新包装成DTensor来实现。- 参数
-
-
func (Callable) – 用于每个本地分片的函数。
-
out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – 在
func的扁平化输出中,期望的DTensor放置位置。如果扁平化的output是一个单一值,则out_placements应为 PlacementType 类型;如果有多个值,out_placements则应是一个包含 PlacementType 值的元组,并且这些值与扁平化后的输出一一对应。对于Tensor输出,我们使用 PlacementType 作为其放置位置(一个 Tuple[Placement] 值)。对于非Tensor的输出,PlacementType 应为 None。需要注意的是,在没有传递任何DTensor参数的情况下,即使 out_placements 不是 None,结果函数也应该忽略期望的放置位置,因为该函数并未使用DTensor。 -
in_placements (Tuple[PlacementType, …], 可选) – 在
func的扁平化输入中指定所需的DTensor位置。如果指定了in_placements,local_map()将检查每个DTensor参数的位置是否与所需位置一致。如果不一致且redistribute_inputs为False,将引发异常。如果redistribute_inputs为True,参数将在传递其本地张量给func前重新分布到所需的位置上。唯一例外是当需要的位置不为None而且参数是一个torch.Tensor时,此时将跳过位置检查并将参数直接传递给func。如果in_placements为None,则不会执行任何位置检查。默认值: None -
device_mesh (
DeviceMesh, 可选) – 所有DTensor放置的设备网格。如果没有指定,将从输入的DTensor的设备网格推断出来。local_map 要求每个DTensor都放在同一个设备网格上。默认值:None。 -
redistribute_inputs (bool, 可选) – 布尔值,表示当输入的放置位置与所需输入的放置位置不同时是否重新分配输入的
DTensor。如果此值为False,且某些DTensor输入具有不同的放置位置,则会引发异常。默认值:False。
-
- 返回值
-
一个
Callable,它将func应用于输入DTensor的每个本地分片,并返回由func返回值构成的新的DTensor。 - 异常
-
-
AssertionError – 如果输入的
DTensor没有放置在相同的设备网格上,或者它们被放置在一个与传入的device_mesh参数不同的设备网格上。 -
AssertionError – 对于任何非DTensor的输出,其在
out_placements中的对应位置必须为None。如果不满足此条件,则会引发AssertionError。 -
ValueError – 如果
redistribute_inputs=False,但输入的DTensor需要根据in_placements进行重新分布。
-
示例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
注意
此 API 目前处于实验阶段,可能随时会进行更改。
- torch.distributed.tensor.experimental.register_sharding(op)
-
register_sharding()是一个实验性 API,允许用户在张量输入和输出为 DTensor 时为操作符注册分片策略。这在以下情况下很有用:(1) 操作符op没有默认的分片策略,例如当op是一个不被DTensor支持的自定义操作符;(2) 用户希望覆盖现有操作符的默认分片策略。- 参数
-
op (Union[OpOverload, List[OpOverload]]) – 用于注册自定义分片函数的操作或操作列表。
- 返回值
-
一个用于包装定义操作分片策略函数的装饰器。所定义的分片策略将注册到 DTensor,并在 DTensor 已实现该操作时覆盖默认分片策略。自定义分片函数接收与原始操作相同的参数(如果参数是
torch.Tensor,则会被替换为 DTensor 内部使用的类似张量的对象)。该函数应返回一系列 2 元组序列,每个元组指定可接受的输出位置及其对应的输入位置。
示例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
此 API 目前处于实验阶段,可能随时会进行更改。