torch.nn.utils.spectral_norm
- torch.nn.utils.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[源代码]
-
对给定模块中的参数应用光谱规范化。
$\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$谱归一化通过使用幂迭代方法计算权重矩阵的谱范数$\sigma$,并在每次
forward()
调用之前重新缩放权重张量,从而稳定生成对抗网络(GAN)中判别器的训练。如果权重张量的维度大于2,则在幂迭代方法中将其重塑为2D以计算谱范数。- 参数
- 返回值
-
具有谱范数挂钩的原始模块
- 返回类型
-
模块T
注意
此功能已通过
torch.nn.utils.parametrizations.spectral_norm()
使用新的参数化功能在torch.nn.utils.parametrize.register_parametrization()
中重新实现。请使用更新后的版本,此函数将在未来的 PyTorch 版本中被弃用。示例:
>>> m = spectral_norm(nn.Linear(20, 40)) >>> m Linear(in_features=20, out_features=40, bias=True) >>> m.weight_u.size() torch.Size([40])