torch.nn.utils.parametrizations.weight_norm

torch.nn.utils.parametrizations.weight_norm(module, name='weight', dim=0)[源代码]

对给定模块中的参数进行权重规范化。

$\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}$

权重归一化是一种重新参数化的技术,它将权重张量的大小和方向分开。具体来说,它会用两个新的参数来替换由 name 指定的原始参数:一个用于指定幅度,另一个用于指定方向。

默认情况下,当dim=0时,范数会为每个输出通道或平面单独计算。若要对整个权重张量计算范数,请设置dim=None

参见 这篇文章

参数
  • module (Module) - 包含该组件的模块

  • name (str, 可选) – 权重参数的名称

  • dim (int, 可选) – 用于计算范数的维度

返回值

具有权重范数挂钩的原始模块

示例:

>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _WeightNorm()
    )
  )
)
>>> m.parametrizations.weight.original0.size()
torch.Size([40, 1])
>>> m.parametrizations.weight.original1.size()
torch.Size([40, 20])
本页目录