RMSNorm
- 类torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源代码]
-
对一批输入数据进行均方根层规范化处理。
此层实现了均方根层归一化论文中描述的操作。
均方根范数在最后一个
D
维度上进行计算,其中D
是normalized_shape
的维数。例如,如果normalized_shape
为(3, 5)
(一个二维形状),则均方根范数会在输入的最后两个维度上进行计算。- 参数
- 形状:
-
-
输入:
-
输出: (与输入的形状相同)
-
示例:
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)
- reset_parameters()[源代码]
-
基于在 __init__ 中的初始化设置来重置参数。