Transformer解码器

classtorch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[源代码]

TransformerDecoder 由 N 个解码层堆叠而成。

参数
  • decoder_layer (TransformerDecoderLayer) – 一个 TransformerDecoderLayer 类的实例(必填)。

  • num_layers (int) – 解码器中子解码层的数量(必填)。

  • norm (Optional[Module]) – 可选的层归一化组件。

示例:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[源代码]

依次将输入(和掩码)传递给解码器层。

参数
  • tgt (Tensor) – 输入解码器的序列(必填)。

  • memory (Tensor) – 编码器最后一层产生的序列(必需)。

  • tgt_mask (Optional[Tensor]) – 目标序列的可选掩码。

  • memory_mask (Optional[Tensor]) – 记忆序列的可选掩码。

  • tgt_key_padding_mask (Optional[Tensor]) – 每批tgt键的掩码(可选)。

  • memory_key_padding_mask (Optional[Tensor]) – 每批的键掩码(可选)。

  • tgt_is_causal (Optional[bool]) – 如果指定,将应用因果掩码作为 tgt mask。默认值:None; 尝试自动检测因果掩码。警告:如果提供不正确的提示(如tgt_is_causal),可能会导致执行错误,并影响前向和后向兼容性。

  • memory_is_causal (bool) – 如果指定,则将因果掩码应用于 memory mask。默认值为False。警告:如果 memory_is_causal 设置不正确,可能会导致执行错误,并影响前向和后向兼容性。

返回类型

Tensor

形状:

参阅Transformer的文档。

本页目录