GLU

torch.nn.GLU(dim=-1)[源代码]

应用 gated 线性单元函数。

${GLU}(a, b)= a \otimes \sigma(b)$ 其中$a$ 表示输入矩阵的第一半,$b$ 表示第二半。

参数

dim (int) – 拆分输入的维度。默认值:-1

形状:
  • 输入: $(\ast_1, N, \ast_2)$,其中 * 表示任意数量的额外维度。

  • 输出为 $(\ast_1, M, \ast_2)$,其中 $M=N/2$

示例:

>>> m = nn.GLU()
>>> input = torch.randn(4, 2)
>>> output = m(input)
本页目录