SyncBatchNorm

classtorch.nn.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, device=None, dtype=None)[源代码]

对N维输入进行批处理规范化。

N维输入是一个带有额外通道维度的[N-2]维输入的小批量,具体参见论文Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 中的相关描述。

$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$

均值和标准差是在所有相同进程组的迷你批次中按维度进行计算的。$\gamma$$\beta$ 是大小为C(其中C是输入大小)的学习参数向量。默认情况下,$\gamma$ 的元素从$\mathcal{U}(0, 1)$ 中抽取,而 $\beta$ 的元素设置为 0。标准差是通过有偏估计器计算得出的,等同于torch.var(input, unbiased=False)

此外,默认情况下,在训练期间,此层会持续计算并保存其均值和方差的运行估计值,并在评估时用于归一化。这些运行估计值通过默认的momentum(动量)0.1 来维护。

如果 track_running_stats 设置为 False,该层将不再保存运行估计值,并在评估过程中使用批处理统计数据。

注意

这个 momentum 参数与优化器类中使用的以及传统意义上的动量概念不同。数学上,这里的运行统计更新规则是$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,其中 $\hat{x}$ 表示估计的统计量,$x_t$ 是新的观测值。

由于批量归一化是在C维度的每个通道上进行的,并在(N, +)切片上计算统计信息,因此通常将其称为体积批量归一化或时空批量归一化。

目前,SyncBatchNorm仅支持每个进程使用单个 GPU 的 DistributedDataParallel (DDP)。在使用 DDP 包装网络之前,请使用torch.nn.SyncBatchNorm.convert_sync_batchnorm()BatchNorm*D层转换为SyncBatchNorm

参数
  • num_features (int) – $C$,表示预期输入大小为 $(N, C, +)$ 中的 $C$

  • eps (浮点数) – 用于数值稳定性,添加到分母中的值。默认值:1e-5

  • momentum (Optional[float]) – 用于运行均值和方差计算的参数。可以设置为None以使用累积移动平均(即简单平均)。默认值:0.1

  • affine (bool) – 一个布尔值,当设置为True时,表示此模块具有可学习的仿射参数。默认值: True

  • track_running_stats (bool) – 一个布尔值,当设置为True时,此模块会跟踪运行均值和方差;当设置为False时,此模块不会进行统计,并将统计缓冲区 running_meanrunning_var 初始化为 None。如果这些缓冲区为None,则在训练和评估模式下始终使用批次统计信息。默认值: True

  • process_group (Optional[Any]) – 统计信息的同步在每个进程组内部独立进行。默认情况下,统计信息在整个进程中同步。

形状:
  • 输入: $(N, C, +)$

  • 输出: $(N, C, +)$ (与输入的形状相同)

注意

批量归一化统计的同步仅在训练过程中发生。当设置 model.eval() 或者 self.trainingFalse 时,同步将被禁用。

示例:

>>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)

>>> # network is nn.BatchNorm layer
>>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
>>> # only single gpu per process is currently supported
>>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
>>>                         sync_bn_network,
>>>                         device_ids=[args.local_rank],
>>>                         output_device=args.local_rank)
classmethodconvert_sync_batchnorm(module, process_group=None)[源代码]

将模型中的所有BatchNorm*D层转换为torch.nn.SyncBatchNorm层。

参数
  • module (nn.Module) – 包含一个或多个BatchNorm*D层的模块

  • process_group (可选) - 用于限定同步范围的进程组,默认值为整个世界。

返回值

原始的module,其中的BatchNorm*D层被转换为新的torch.nn.SyncBatchNorm层对象。

示例:

>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>            torch.nn.Linear(20, 100),
>>>            torch.nn.BatchNorm1d(100),
>>>          ).cuda()
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
本页目录