多头注意力机制
- 类torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]
-
允许模型同时从不同的表示子空间中获取信息。
论文中提出的方法:注意力机制全靠它。
多头注意力(Multi-Head Attention)的定义是:
$\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O$即$head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$。
nn.MultiHeadAttention
将会在可能的情况下使用优化版的scaled_dot_product_attention()
实现。除了支持新的
scaled_dot_product_attention()
函数外,为了加速推理,多头注意力机制(MHA)还将采用快速路径推理并支持嵌套张量,但前提是满足以下条件。-
自注意力机制正在进行计算(即,
query
、key
和value
使用同一个张量)。 -
输入是以批处理形式提供的(3D),且
batch_first == True
-
自动求梯度要么被禁用了(使用了
torch.inference_mode
或torch.no_grad
),要么传入的张量参数中没有设置requires_grad
-
训练已关闭(使用
.eval()
) -
add_bias_kv
的值为False
-
add_zero_attn
为False
-
kdim
和vdim
与embed_dim
相等 -
如果传入一个NestedTensor,则不会传递
key_padding_mask
和attn_mask
-
自动 casting 已关闭
如果启用了优化的推理快速路径实现,可以为
query
/key
/value
传递一个NestedTensor 来更高效地表示填充部分,而无需使用填充掩码。在这种情况下,将返回一个 NestedTensor,并且可以期望获得与输入中填充比例成正比的额外加速。- 参数
-
-
embed_dim - 模型的总的维度。
-
num_heads – 并行注意力头的数量。注意,
embed_dim
会分布在num_heads
个头上(即每个头的维度为embed_dim // num_heads
)。 -
dropout -
attn_output_weights
的 dropout 概率。默认值:0.0
(不使用 dropout)。 -
bias - 如果指定了偏置,它将被添加到输入和输出的投影层中。默认值为
True
。 -
add_bias_kv - 如果指定了此参数,则会在dim=0位置为键和值序列添加偏置。默认情况下,该参数的值为
False
。 -
add_zero_attn – 如果设置此参数,将在
dim=1
位置为键和值序列添加一个全零的新批次。默认值:False
。 -
kdim - 键的特征总数。默认值:
None
(即使用kdim=embed_dim
)。 -
vdim - 值的总特征数。默认为
None
(即使用vdim=embed_dim
)。 -
batch_first – 如果为
True
,则输入和输出张量的维度为 (批次, 序列, 特征)。默认值:False
(序列, 批次, 特征)。
-
示例:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源代码]
-
利用查询、键和值嵌入来计算注意力输出。
支持可选参数,例如 padding、mask 和 attention 权重。
- 参数
-
-
query (Tensor) – 查询嵌入的形状为 $(L, E_q)$(未批量处理输入)或$(L, N, E_q)$(当
batch_first=False
时)或$(N, L, E_q)$(当batch_first=True
时),其中 $L$ 是目标序列长度,$N$ 是批量大小,而$E_q$是查询嵌入维度embed_dim
。查询与键值对进行比较以生成输出。更多详细信息请参阅“Attention Is All You Need”。 -
key (Tensor) – 键嵌入的形状为:
未批量处理时为 $(S, E_k)$,batch_first=False
时为 $(S, N, E_k)$,batch_first=True
时为 $(N, S, E_k)$。
其中 $S$ 表示源序列长度,$N$ 表示批量大小,而 $E_k$ 是键嵌入维度kdim
。详情请参阅“Attention Is All You Need”。 -
value (Tensor) – 未批量输入的值嵌入形状为$(S, E_v)$,当
batch_first=False
时形状为$(S, N, E_v)$,当batch_first=True
时形状为$(N, S, E_v)$。其中$S$表示源序列长度,$N$表示批量大小,而$E_v$是值嵌入维度vdim
。详情请参阅“Attention Is All You Need”。 -
key_padding_mask (Optional[Tensor]) – 如果指定了,一个形状为$(N, S)$的掩码,表示在注意力计算中要忽略
key
中的哪些元素(即视为“填充”)。对于未批处理的query,形状应为$(S)$。支持二进制和浮点掩码。对于二进制掩码,True
表示相应的key
值将被忽略用于注意力计算;对于浮点掩码,则直接加到相应的key
值上。 -
need_weights (bool) – 如果设置为
True
,除了返回attn_outputs
之外,还会返回attn_output_weights
。将need_weights
设为False
可以使用优化的scaled_dot_product_attention
来实现MHA的最佳性能。默认值:True
。 -
attn_mask (Optional[Tensor]) – 如果指定了,这是一个二维或三维掩码,用于阻止对某些位置的注意力。其形状可以是$(L, S)$ 或 $(N\cdot\text{num\_heads}, L, S)$,其中 $N$ 是批量大小,$L$ 是目标序列长度,而$S$ 是源序列长度。二维掩码会在整个批次中广播,三维掩码则允许为每个条目使用不同的掩码。支持二进制和浮点掩码:对于二进制掩码,
True
表示相应位置不允许进行注意力处理;对于浮点掩码,其值将加到注意力权重中。如果同时提供了 attn_mask 和 key_padding_mask,则它们的类型应该一致。 -
average_attn_weights (bool) – 如果为 true,表示返回的
attn_weights
应该在各个头之间进行平均。否则,attn_weights
将按每个头分别提供。注意,此标志仅在need_weights=True
时有效。默认值:True
(即在各个头之间平均权重)。 -
is_causal (bool) – 如果指定,则将因果掩码作为注意力掩码应用。默认值为
False
。警告:is_causal
提示attn_mask
是因果掩码。提供不正确的提示可能导致执行错误,包括前向和后向兼容性问题。
-
- 返回类型
- 输出:
-
-
attn_output - 注意力输出的形状为:
当输入未分批时,形状为$(L, E)$;
当batch_first=False
时,形状为$(L, N, E)$;
当batch_first=True
时,形状为$(N, L, E)$。
其中,$L$表示目标序列长度,$N$表示批量大小,而$E$是嵌入维度embed_dim
。 -
attn_output_weights - 仅在
need_weights=True
时返回。如果average_attn_weights=True
,当输入未批量处理时,返回形状为$(L, S)$的平均注意力权重;若输入已批量处理,则返回形状为$(N, L, S)$的平均注意力权重。其中N
表示批大小,L
为目标序列长度,而S
是源序列长度。如果average_attn_weights=False
,当输入未批量处理时返回形状为$(\text{num\_heads}, L, S)$的每个头注意力权重;若输入已批量处理,则返回形状为$(N, \text{num\_heads}, L, S)$的每个头注意力权重。
注意
对于未批量处理的输入,batch_first 参数会被忽略。
-
- merge_masks(attn_mask, key_padding_mask, query)[源代码]
-
确定掩码类型,并在需要时进行组合。
如果只提供一个掩码,则返回该掩码及其对应的掩码类型。如果同时提供了两个掩码,它们会被扩展到形状
(batch_size, num_heads, seq_len, seq_len)
,然后通过逻辑or
进行组合,并返回掩码类型 2。
:param attn_mask: 形状为(seq_len, seq_len)
的注意力掩码,掩码类型 0
:param key_padding_mask: 形状为(batch_size, seq_len)
的填充掩码,掩码类型 1
:param query: 形状为(batch_size, seq_len, embed_dim)
的查询嵌入- 返回值
-
合并掩码类型:表示合并掩码的类型,可以是 0、1 或 2。
- 返回类型
-
合并后的掩码
-