torch.cuda.comm.gather

torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[源代码]

从多个GPU设备收集张量。

参数
  • tensors (Iterable[Tensor]) – 需要收集的张量迭代器。除了dim维度之外,所有张量在其他维度上的大小必须相同。

  • dim (int, 可选) – 指定用于连接张量的维度。默认值:0

  • destination (torch.device, str, 或 int, 可选) – 输出设备。可以是 CPU 或 CUDA。默认值:当前的 CUDA 设备。

  • out (Tensor, 可选, 关键字参数) – 用于存储 gather 结果的张量。其大小必须与 tensors 匹配,除了 dim 维度外,该维度的大小应等于所有张量在相应维度上的大小之和。此张量可以位于 CPU 或 CUDA 上。

注意

在指定 out 时,destination 不能被指定。

返回值
  • 如果指定 destination

    位于destination设备上的张量,它是通过沿着dim维度连接tensors得到的结果。

  • 如果指定 out

    张量out现在包含了沿dim维度连接tensors的结果。

本页目录