torch.nn.utils.spectral_norm

torch.nn.utils.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[源代码]

对给定模块中的参数应用光谱规范化。

$\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$

谱归一化通过使用幂迭代方法计算权重矩阵的谱范数$\sigma$,并在每次forward()调用之前重新缩放权重张量,从而稳定生成对抗网络(GAN)中判别器的训练。如果权重张量的维度大于2,则在幂迭代方法中将其重塑为2D以计算谱范数。

参阅用于生成对抗网络的谱规范化

参数
  • module (nn.Module) - 包含该组件的模块

  • name (str, 可选) – 权重参数的名称

  • n_power_iterations (int, 可选) – 用于计算谱范数的幂迭代次数

  • eps (float, 可选) – 用于在计算范数时保证数值稳定的 epsilon 值

  • dim (int, 可选) – 与输出数量对应的维度,默认值为0,但对于是 ConvTranspose{1,2,3}d 实例的模块,默认值为1

返回值

具有谱范数挂钩的原始模块

返回类型

模块T

注意

此功能已通过torch.nn.utils.parametrizations.spectral_norm() 使用新的参数化功能在 torch.nn.utils.parametrize.register_parametrization() 中重新实现。请使用更新后的版本,此函数将在未来的 PyTorch 版本中被弃用。

示例:

>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
本页目录