多头注意力机制

torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]
dequantize()[源代码]

用于将量化后的MHA转换回浮点数的工具。

动机是,将量化版本中的权重格式转换回浮点格式并不简单。

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源代码]
注意:

请参阅forward()以获取更多信息。

参数
  • query (Tensor) – 将查询和一组键值对映射到输出。详情请参阅“Attention Is All You Need”。

  • key (Tensor) – 将查询和一组键值对映射到输出。更多细节请参阅“Attention Is All You Need”。

  • value (Tensor) – 将一个查询和一组键值对映射到输出。详情请参阅《Attention Is All You Need》。

  • key_padding_mask (Optional[Tensor]) – 如果提供,指定的填充元素将在密钥中被忽略。当给定一个二进制掩码时,如果值为 True,则注意力层上对应的值将被忽略。

    key_padding_mask (Optional[Tensor]) – 如果提供,指定填充元素在密钥中将被忽略。当给定一个二进制掩码时,如果值为 True,则注意力层上对应的值将被忽略。

  • need_weights (bool) – 是否输出 attn_output_weights。

  • attn_mask (Optional[Tensor]) – 2D 或 3D 掩码,用于防止对某些位置的注意力。2D 掩码将在所有批次中广播,而 3D 掩码允许为每个批次的不同条目指定不同的掩码。

返回类型

Tuple[Tensor, Optional[Tensor]]

形状:
  • 输入:

  • query: (L,N,E)(L, N, E) 其中 L 表示目标序列长度,N 表示批量大小,E 表示嵌入维度。如果 batch_firstTrue,则格式变为 (N,L,E)(N, L, E)

  • key: (S,N,E)(S, N, E),其中 S 表示源序列长度,N 表示批量大小,E 表示嵌入维度。如果 batch_firstTrue,则 key 变为 (N,S,E)(N, S, E)

  • 值: (S,N,E)(S, N, E),其中 S 表示源序列长度,N 表示批量大小,E 表示嵌入维度。如果 batch_firstTrue,则值为(N,S,E)(N, S, E)

  • key_padding_mask: (N,S)(N, S) 其中 N 表示批量大小,S 表示源序列长度。如果提供了一个 BoolTensor,值为 True 的位置将被忽略,而值为 False 的位置保持不变。

  • attn_mask: 2D 掩码 (L,S)(L, S),其中 L 是目标序列长度,S 是源序列长度。3D 掩码 (N×num_heads,L,S)(N \times num\_heads, L, S),其中 N 是批量大小,L 是目标序列长度,S 是源序列长度。attn_mask 确保位置 i 只能关注未屏蔽的位置。如果提供的是 BoolTensor,则值为 True 的位置不允许关注,而值为 False 的位置保持不变。如果提供的是 FloatTensor,则会加到注意力权重上。

  • is_causal: 如果指定了此参数,则会使用因果掩码作为注意力掩码,并且与提供 attn_mask 参数相互排斥。默认情况下,该参数的值为 False

  • average_attn_weights: 如果为 true,则表示返回的 attn_weights 应该在各个头之间进行平均计算;否则,attn_weights 将按每个头分别提供。需要注意的是,此标志仅在 need_weights=True. 时生效。默认值:True(即在各头之间平均权重)。

  • 输出结果:

  • attn_output: (L,N,E)(L, N, E) 其中 L 表示目标序列长度,N 表示批量大小,E 表示嵌入维度。如果 batch_firstTrue,则形状变为 (N,L,E)(N, L, E)

  • attn_output_weights: 如果 average_attn_weights=True,返回形状为 (N,L,S)(N, L, S) 的平均注意力权重,其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。如果 average_attn_weights=False,则返回每个头的注意力权重,形状为 (N,numheads,L,S)(N, num_heads, L, S)

本页目录