GRUCell
- 类torch.nn.GRUCell(input_size, hidden_size, bias=True, device=None, dtype=None)[源代码]
-
一个门控循环单元(GRU)。
$\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ h' = (1 - z) \odot n + z \odot h \end{array}$其中$\sigma$ 是 sigmoid 函数,$\odot$ 表示 Hadamard 乘积。
- 参数
- 输入类型:input, hidden
-
-
input : 含有输入特征的张量
-
hidden:表示批次中每个元素的初始隐藏状态的张量。如果没有提供,默认值为零。
-
- 输出: h'
-
-
h':包含批次中每个元素下一个隐藏状态的张量
-
- 形状:
-
-
输入: $(N, H_{in})$ 或 $(H_{in})$ 张量,包含输入特征。其中 $H_{in}$ 等于 input_size。
-
hidden: $(N, H_{out})$ 或 $(H_{out})$ 张量,包含初始隐藏状态。其中 $H_{out}$ 等于 hidden_size。如果没有提供,默认值为零。
-
输出: $(N, H_{out})$ 或 $(H_{out})$ 形式的张量,包含下一个隐藏状态。
-
- 变量
-
-
weight_ih (torch.Tensor) – 输入到隐藏层的可学习权重,其形状为(3*hidden_size, input_size)
-
weight_hh (torch.Tensor) – 隐藏层到隐藏层的可学习权重,形状为(3*hidden_size, hidden_size)
-
bias_ih - 输入到隐藏层的可学习偏置,其形状为(3*hidden_size)
-
bias_hh – 隐藏层到隐藏层的可学习偏置项,其形状为(3*hidden_size)
-
注意
所有的权重和偏置都从$\mathcal{U}(-\sqrt{k}, \sqrt{k})$ 初始化,其中 $k = \frac{1}{\text{hidden\_size}}$。
在某些ROCm设备上,当使用float16输入时,此模块会采用不同的精度进行反向传播。
示例:
>>> rnn = nn.GRUCell(10, 20) >>> input = torch.randn(6, 3, 10) >>> hx = torch.randn(3, 20) >>> output = [] >>> for i in range(6): ... hx = rnn(input[i], hx) ... output.append(hx)