RMSNorm

torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源代码]

对一批输入数据进行均方根层规范化处理。

此层实现了均方根层归一化论文中描述的操作。

$y = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma$

均方根范数在最后一个D维度上进行计算,其中Dnormalized_shape的维数。例如,如果normalized_shape(3, 5)(一个二维形状),则均方根范数会在输入的最后两个维度上进行计算。

参数
  • normalized_shape (intlisttorch.Size) –

    输入形状不符合预期的输入大小

    $[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]$

    如果使用单个整数,它会被视为一个包含这个整数的列表。此模块会在最后一个维度上进行归一化处理,而这个维度的大小应该与提供的整数值一致。

  • eps (Optional[float]) – 用于数值稳定性,添加到分母中的值。默认值: torch.finfo(x.dtype).eps()

  • elementwise_affine (bool) – 一个布尔值,当设置为True时,此模块具有可学习的每个元素的仿射参数。权重初始化为1,偏差初始化为0。默认值: True

形状:
  • 输入: $(N, *)$

  • 输出: $(N, *)$ (与输入的形状相同)

示例:

>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)
extra_repr()[源代码]

模块的额外信息。

返回类型

str

forward(x)[源代码]

进行前向计算。

返回类型

Tensor

reset_parameters()[源代码]

基于在 __init__ 中的初始化设置来重置参数。

本页目录