数据并行
- classtorch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[源代码]
-
在模块层面实现数据并行处理。
此容器通过在指定设备上按批次维度分割输入来并行化给定
模块
的应用(其他对象将在每台设备上复制一次)。在前向传递中,该模块会在每个设备上进行复制,并且每个副本处理一部分输入。在反向传递过程中,来自每个副本的梯度会被累加到原始模块。批次大小应该大于所使用的GPU的数量。
警告
建议使用
DistributedDataParallel
类(而不是此类)来进行多GPU训练,即使只有一个节点也是如此。参见:使用nn.parallel.DistributedDataParallel
而非multiprocessing
或nn.DataParallel
和分布式数据并行处理。任意位置和关键字输入都可以传递给DataParallel,但某些类型会受到特殊处理。张量会在指定的dim上进行分散(默认为0)。元组、列表和字典会被浅拷贝。其他类型的对象将在不同的线程中共享,在模型的前向传递过程中如果对其进行写操作可能会导致数据损坏。
在运行此
DataParallel
模块之前,并行化的module
的参数和缓冲区必须位于device_ids[0]
上。警告
在每次前向传递中,
module
在每个设备上都会被复制。因此,在forward
中对运行中的模块所做的任何更新将会丢失。例如,如果module
有一个计数器属性,并且该属性在每次forward
中都会递增,则由于这些副本在forward
结束后会被销毁,所以计数器将始终保持初始值。然而,DataParallel
保证了位于device[0]
上的副本与基础并行化module
的参数和缓冲区共享存储。因此,在device[0]
上对参数或缓冲区所做的就地更新会被记录下来。例如,BatchNorm2d
和spectral_norm()
依赖于这种行为来更新缓冲区。警告
在
module
及其子模块上定义的前向和后向钩子将被调用len(device_ids)
次,每次输入都位于特定设备上。特别地,这些钩子仅保证相对于相应设备上的操作按正确顺序执行。例如,并不能保证通过register_forward_pre_hook()
设置的每个钩子在所有len(device_ids)
个forward()
调用之前执行,但每个这样的钩子会在相应设备的forward()
调用前执行。警告
当
module
在forward()
中返回一个标量(即0维张量)时,这个包装器会返回一个长度等于数据并行所使用设备数量的向量,并包含每个设备的结果。注意
在使用
DataParallel
包装的Module
中应用pack sequence -> recurrent network -> unpack sequence
模式时,存在一些细微差别。有关详细信息,请参阅FAQ中的我的循环网络无法与数据并行性一起工作部分。- 参数
-
-
module (Module) – 需要进行并行处理的模块
-
device_ids (列表 of 整数 或 torch.device) – 指定的 CUDA 设备(默认:使用所有可用设备)
-
output_device (int 或 torch.device) – 输出的设备位置,默认为 device_ids[0]
-
- 变量
-
module (Module) – 需要进行并行处理的模块
示例:
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU