BatchNorm3d
- classtorch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)[源代码]
-
对5D输入进行批规范化处理。
5D 是指在 3D 输入的基础上增加了一个通道维度,形成一个批量数据,如论文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 中所述。
$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$均值和标准差是针对每个维度在小批量数据上进行计算的。参数向量$\gamma$ 和 $\beta$ 的大小为C(其中C是输入大小)。默认情况下,$\gamma$ 的元素被设置为1,而$\beta$ 的元素被设置为0。在训练阶段的前向传播中,标准差通过有偏估计器计算得出,等同于
torch.var(input, unbiased=False)
。然而,在移动平均的标准差值则是通过无偏估计器计算得出的,等同于torch.var(input, unbiased=True)
。此外,默认情况下,在训练期间,此层会持续计算并保存其均值和方差的运行估计值,并在评估时用于归一化。这些运行估计值通过默认的
momentum
(动量)0.1 来维护。如果
track_running_stats
设置为False
,该层将不再保存运行估计值,并在评估过程中使用批处理统计数据。注意
这个
momentum
参数与优化器类中使用的以及传统意义上的动量概念不同。数学上,这里的运行统计更新规则是$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,其中 $\hat{x}$ 表示估计的统计量,$x_t$ 是新的观测值。由于批量归一化是在 C 维度上进行的,并在 (N, D, H, W) 切片上计算统计信息,因此通常将其称为体积批量归一化或时空批量归一化。
- 参数
-
-
num_features (int) – $C$,表示预期输入大小为$(N, C, D, H, W)$中的
值 -
eps (float) — 一个添加到分母的值,用于保证数值稳定性。默认值为 1e-5。
-
momentum (Optional[float]) – 用于运行均值和方差计算的参数。可以设置为
None
以使用累积移动平均(即简单平均)。默认值:0.1 -
affine (bool) – 一个布尔值,当设置为
True
时,表示此模块具有可学习的仿射参数。默认值:True
-
track_running_stats (bool) – 一个布尔值,当设置为
True
时,此模块会跟踪运行均值和方差;当设置为False
时,此模块不会进行统计,并将统计缓冲区running_mean
和running_var
初始化为None
。如果这些缓冲区为None
,则在训练和评估模式下始终使用批次统计信息。默认值:True
-
- 形状:
-
-
输入: $(N, C, D, H, W)$
-
输出: $(N, C, D, H, W)$ (与输入的形状相同)
-
示例:
>>> # With Learnable Parameters >>> m = nn.BatchNorm3d(100) >>> # Without Learnable Parameters >>> m = nn.BatchNorm3d(100, affine=False) >>> input = torch.randn(20, 100, 35, 45, 10) >>> output = m(input)