阈值
- classtorch.nn.Threshold(threshold, value, inplace=False)[源代码]
-
对输入张量的每个元素应用阈值处理。
阈值是指:
$y = \begin{cases} x, &\text{ if } x > \text{threshold} \\ \text{value}, &\text{ otherwise } \end{cases}$- 形状:
-
-
输入: $(*)$,其中$*$表示任意维度的数量。
-
输出: $(*)$,形状与输入相同。
-
示例:
>>> m = nn.Threshold(0.1, 20) >>> input = torch.randn(2) >>> output = m(input)