torch.nn.parallel.data_parallel

torch.nn.parallel.data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None)[源代码]

在device_ids指定的GPU上并行评估module(input)。

这是 DataParallel 模块的功能实现版。

参数
  • module (Module) – 需要并行评估的模块

  • inputs (Tensor) — 模块的输入数据

  • device_ids (列表 of 整数torch.device) – 指定用于复制模块的 GPU ID 列表

  • output_device (列表 of 整数 or torch.device) – 输出的 GPU 位置。使用 -1 表示在 CPU 上运行。(默认值: device_ids[0])

返回值

位于输出设备上的一个张量,包含模块(input)的结果

返回类型

Tensor

本页目录