RNNBase

torch.nn.RNNBase(mode, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)[源代码]

RNN模块(包括RNN、LSTM和GRU)的基础类。

实现了RNN、LSTM和GRU类共有的功能,如模块初始化和参数存储管理的实用方法。

注意

RNNBase 类未实现 forward 方法。

注意

LSTM 和 GRU 类覆盖了 RNNBase 实现的一些方法。

flatten_parameters()[源代码]

重置参数数据指针,使它们能够使用更高效的代码路径。

目前,这只有在模块位于 GPU 上并且启用了 cuDNN 时才有效,否则不会执行任何操作。

本页目录