RNNBase
- 类torch.nn.RNNBase(mode, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)[源代码]
-
RNN模块(包括RNN、LSTM和GRU)的基础类。
实现了RNN、LSTM和GRU类共有的功能,如模块初始化和参数存储管理的实用方法。
注意
RNNBase 类未实现 forward 方法。
注意
LSTM 和 GRU 类覆盖了 RNNBase 实现的一些方法。
- flatten_parameters()[源代码]
-
重置参数数据指针,使它们能够使用更高效的代码路径。
目前,这只有在模块位于 GPU 上并且启用了 cuDNN 时才有效,否则不会执行任何操作。