torch.cuda.comm.scatter
- torch.cuda.comm.scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None)[源代码]
-
将张量分布在多个GPU上。
- 参数
-
-
tensor (Tensor) – 需要进行分散操作的张量。该张量可以在CPU或GPU上运行。
-
devices (Iterable[torch.device, str 或 int], 可选) – 一个 GPU 设备的迭代器,用于在这组设备之间进行分配。
-
chunk_sizes (Iterable[int], 可选) – 每个设备上放置的块大小。它的长度应与
devices
相匹配,并且总和应等于tensor.size(dim)
。如果没有指定,tensor
将被分成大小相等的块。 -
dim (int, 可选) – 沿着该维度对
tensor
进行切分。默认值:0
。 -
streams (Iterable[torch.cuda.Stream], optional) – 一个可迭代的 Streams 对象,在这些 Streams 中执行 scatter 操作。如果未指定,则使用默认 Stream。
-
out (Sequence[Tensor], 可选, 关键字参数) – 用于存储输出结果的 GPU 张量。这些张量的大小必须与
tensor
匹配,除了在dim
维度上,其总大小必须等于tensor.size(dim)
。
-
注意
必须且只能指定
devices
和out
中的一个。如果指定了out
,则不能指定chunk_sizes
,其值将根据out
的大小推断得出。- 返回值
-
-
-
如果指定
devices
, -
一个包含
tensor
片段的元组,这些片段分布在devices
上。
-
如果指定
-
-
如果指定
out
, -
一个包含
out
张量的元组,每个张量都包含tensor
的一部分。
-
如果指定
-