torch.nn.utils.parametrizations.weight_norm
- torch.nn.utils.parametrizations.weight_norm(module, name='weight', dim=0)[源代码]
-
对给定模块中的参数进行权重规范化。
$\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}$权重归一化是一种重新参数化的技术,它将权重张量的大小和方向分开。具体来说,它会用两个新的参数来替换由
name
指定的原始参数:一个用于指定幅度,另一个用于指定方向。默认情况下,当
dim=0
时,范数会为每个输出通道或平面单独计算。若要对整个权重张量计算范数,请设置dim=None
。参见 这篇文章
示例:
>>> m = weight_norm(nn.Linear(20, 40), name='weight') >>> m ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _WeightNorm() ) ) ) >>> m.parametrizations.weight.original0.size() torch.Size([40, 1]) >>> m.parametrizations.weight.original1.size() torch.Size([40, 20])