torch.nn.utils.weight_norm
- torch.nn.utils.weight_norm(module, name='weight', dim=0)[源代码]
-
对给定模块中的参数进行权重规范化。
$\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}$权重归一化是一种重新参数化的技术,它将权重张量的大小与其方向解耦。具体来说,它会用两个新的参数来替换原来的
name
(例如'weight'
):一个用于表示幅度(如'weight_g'
),另一个用于表示方向(如'weight_v'
)。权重归一化通过在每次调用forward()
之前,根据这两个参数重新计算出原始的权重张量来实现。默认情况下,当
dim=0
时,范数会为每个输出通道或平面单独计算。若要对整个权重张量计算范数,请设置dim=None
。参见 这篇文章
警告
此函数已弃用,请使用
torch.nn.utils.parametrizations.weight_norm()
,它采用现代参数化 API。新的weight_norm
与从旧的weight_norm
生成的state_dict
兼容。迁移指南:
-
权重的幅度(
weight_g
)和方向(weight_v
)现在分别表示为parametrizations.weight.original0
和parametrizations.weight.original1
。如果你对此感到困扰,请在GitHub 上留言。 -
要删除权重归一化的重参数化,请使用
torch.nn.utils.parametrize.remove_parametrizations()
。 -
权重不再是在模块的前向传播时一次性重新计算;而是每次访问时都会重新计算。要恢复旧的行为,请在调用相关模块之前使用
torch.nn.utils.parametrize.cached()
。
- 参数
- 返回值
-
具有权重范数挂钩的原始模块
- 返回类型
-
模块T
示例:
>>> m = weight_norm(nn.Linear(20, 40), name='weight') >>> m Linear(in_features=20, out_features=40, bias=True) >>> m.weight_g.size() torch.Size([40, 1]) >>> m.weight_v.size() torch.Size([40, 20])
-