torch.cuda.comm.gather
- torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[源代码]
-
从多个GPU设备收集张量。
- 参数
-
-
tensors (Iterable[Tensor]) – 需要收集的张量迭代器。除了
dim
维度之外,所有张量在其他维度上的大小必须相同。 -
dim (int, 可选) – 指定用于连接张量的维度。默认值:
0
。 -
destination (torch.device, str, 或 int, 可选) – 输出设备。可以是 CPU 或 CUDA。默认值:当前的 CUDA 设备。
-
out (Tensor, 可选, 关键字参数) – 用于存储 gather 结果的张量。其大小必须与
tensors
匹配,除了dim
维度外,该维度的大小应等于所有张量在相应维度上的大小之和。此张量可以位于 CPU 或 CUDA 上。
-
注意
在指定
out
时,destination
不能被指定。- 返回值
-
-
-
如果指定
destination
, -
位于
destination
设备上的张量,它是通过沿着dim
维度连接tensors
得到的结果。
-
如果指定
-
-
如果指定
out
, -
张量
out
现在包含了沿dim
维度连接tensors
的结果。
-
如果指定
-