Hardshrink

torch.nn.Hardshrink(lambd=0.5)[源代码]

按元素应用硬收缩(Hardshrink)函数。

Hardshrink 的定义是:

$\text{HardShrink}(x) = \begin{cases} x, & \text{ if } x > \lambda \\ x, & \text{ if } x < -\lambda \\ 0, & \text{ otherwise } \end{cases}$
参数

lambd (float) – $\lambda$ 值,用于 Hardshrink 公式。默认值:0.5

形状:
  • 输入: $(*)$,其中$*$表示任意维度的数量。

  • 输出: $(*)$,形状与输入相同。

{BASE_RAW_UPLOAD_URL}/pytorch-doc-2.5/955b766439e8a312baa729ee3b76c621.png

示例:

>>> m = nn.Hardshrink()
>>> input = torch.randn(2)
>>> output = m(input)
本页目录