RMSNorm
- 类torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源代码]
-
对一批输入数据进行均方根层规范化处理。
此层实现了均方根层归一化论文中描述的操作。
$y = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma$均方根范数在最后一个
D
维度上进行计算,其中D
是normalized_shape
的维数。例如,如果normalized_shape
为(3, 5)
(一个二维形状),则均方根范数会在输入的最后两个维度上进行计算。- 参数
-
-
normalized_shape (int 或 list 或 torch.Size) –
输入形状不符合预期的输入大小
$[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]$如果使用单个整数,它会被视为一个包含这个整数的列表。此模块会在最后一个维度上进行归一化处理,而这个维度的大小应该与提供的整数值一致。
-
eps (Optional[float]) – 用于数值稳定性,添加到分母中的值。默认值:
torch.finfo(x.dtype).eps()
-
elementwise_affine (bool) – 一个布尔值,当设置为
True
时,此模块具有可学习的每个元素的仿射参数。权重初始化为1,偏差初始化为0。默认值:True
。
-
- 形状:
-
-
输入: $(N, *)$
-
输出: $(N, *)$ (与输入的形状相同)
-
示例:
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)
- reset_parameters()[源代码]
-
基于在 __init__ 中的初始化设置来重置参数。