torch.nn.utils.parametrizations.spectral_norm

torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[源代码]

对给定模块中的参数应用光谱规范化。

$\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$

当应用到向量上时,它会简化为

$\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}$

谱归一化通过减少生成对抗网络(GAN)中判别器模型的Lipschitz常数来稳定训练。$\sigma$ 通过每次访问权重时执行一次幂迭代 来近似计算。如果权重张量的维度大于2,则在进行幂迭代时将其重塑为二维,以获得谱范数。

参阅用于生成对抗网络的谱规范化

注意

此函数使用 register_parametrization() 中的参数化功能来实现。它是 torch.nn.utils.spectral_norm() 的重新实现。

注意

当注册此约束条件时,与最大奇异值相关的奇异向量会被估计而不是随机采样。然后在使用模块的训练模式访问张量时,通过执行n_power_iterations幂迭代法来更新这些向量。

注意

如果_SpectralNorm模块(即module.parametrization.weight[idx])在被移除时处于训练模式,它会执行一次幂迭代。如果你想避免这次迭代,请在移除该模块之前将其设置为评估模式。

参数
  • module (nn.Module) - 包含该组件的模块

  • name (str, 可选) – 权重参数的名称,默认值为"weight"

  • n_power_iterations (int, 可选) – 计算谱范数的幂迭代次数。默认值为:1

  • eps (float, optional) – 用于计算范数时的数值稳定性参数。默认值:1e-12

  • dim (int, 可选) – 与输出数量对应的维度。默认值:对于不是 ConvTranspose{1,2,3}d 实例的模块,默认为0,而对于是这些实例的模块,默认为1

返回值

具有新参数化并注册到指定权重的原始模块

返回类型

Module

示例:

>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0081, grad_fn=<AmaxBackward0>)
本页目录