Transformer解码器
- classtorch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[源代码]
-
TransformerDecoder 由 N 个解码层堆叠而成。
- 参数
-
-
decoder_layer (TransformerDecoderLayer) – 一个 TransformerDecoderLayer 类的实例(必填)。
-
num_layers (int) – 解码器中子解码层的数量(必填)。
-
- 示例:
-
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) >>> out = transformer_decoder(tgt, memory)
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[源代码]
-
依次将输入(和掩码)传递给解码器层。
- 参数
-
-
tgt (Tensor) – 输入解码器的序列(必填)。
-
memory (Tensor) – 编码器最后一层产生的序列(必需)。
-
tgt_is_causal (Optional[bool]) – 如果指定,将应用因果掩码作为
tgt mask
。默认值:None
; 尝试自动检测因果掩码。警告:如果提供不正确的提示(如tgt_is_causal
),可能会导致执行错误,并影响前向和后向兼容性。 -
memory_is_causal (bool) – 如果指定,则将因果掩码应用于
memory mask
。默认值为False
。警告:如果memory_is_causal
设置不正确,可能会导致执行错误,并影响前向和后向兼容性。
-
- 返回类型
- 形状:
-
参阅
Transformer
的文档。