torch.nn.functional.scaled_dot_product_attention
- torch.nn.functional.scaled_dot_product_attention()>
-
- scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0)
-
is_causal=False, scale=None, enable_gqa=False) -> Tensor:
根据查询、键和值张量计算缩放点积注意力。如果有提供可选的注意力掩码,则会使用它;如果指定了大于0.0的概率,则会应用 dropout。可选的 scale 参数必须通过关键字传递。
# Efficient implementation equivalent to the following: def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask if enable_gqa: key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value
警告
此功能为测试版,可能随时更改。
警告
此函数始终根据指定的
dropout_p
参数应用 dropout。为了在评估期间禁用 dropout,请确保在调用该函数的模块处于非训练模式时传递值0.0
。例如:
class MyModel(nn.Module): def __init__(self, p=0.5): super().__init__() self.p = p def forward(self, ...): return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
注意
目前支持三种缩放点积注意力的实现方式:
-
在C++中定义的与上述公式匹配的PyTorch实现
当使用 CUDA 后端时,该函数会调用优化的内核来提升性能。而对于其他所有后端,则使用 PyTorch 的实现。
所有实现默认都是启用的。缩放点积注意力会根据输入自动选择最优的实现方式。为了更精细地控制所选的实现方式,提供了一些函数用于启用工和禁用不同的实现方式。推荐使用上下文管理器来实现这一点。
-
torch.nn.attention.sdpa_kernel()
:一个用于启用或禁用特定实现的上下文管理器。 -
torch.backends.cuda.enable_flash_sdp()
:全局启用或禁用 FlashAttention。 -
torch.backends.cuda.enable_mem_efficient_sdp()
: 全局开启或关闭内存高效的注意力机制。 -
torch.backends.cuda.enable_math_sdp()
: 全局启用或禁用 PyTorch 的 C++ 实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现,可以使用
torch.nn.attention.sdpa_kernel()
来禁用 PyTorch C++ 实现。如果没有可用的融合实现,则会引发警告,说明无法运行该融合实现的原因。由于融合浮点运算的特性,此函数的输出会根据所选的后端内核有所不同。C++ 实现支持 torch.float64,在需要更高精度时可以使用。对于数学后端,如果输入为 torch.half 或 torch.bfloat16,则所有中间结果都保持在 torch.float。
了解更多信息,请参见数值精度
组查询注意力(GQA)是一项实验性功能。当前,它仅在CUDA张量上支持Flash_attention和数学内核,而不支持嵌套张量。GQA的限制条件如下:
-
number_of_heads_query % number_of_heads_key_value == 0,
-
number_of_heads_key 等于 number_of_heads_value
注意
在某些情况下,当张量位于 CUDA 设备上并使用 CuDNN 时,此操作符可能会选择一个非确定性算法以提高性能。如果你不希望这样,可以通过将
torch.backends.cudnn.deterministic = True
设置为True
来使操作具有确定性(这可能会影响性能)。有关更多信息,请参阅可重复性。- 参数
-
-
query (Tensor) – 查询张量,形状为 $(N, ..., Hq, L, E)$。
-
key (Tensor) – 关键张量;其形状为 $(N, ..., H, S, E)$。
-
value (Tensor) – 值张量;其形状为 $(N, ..., H, S, Ev)$。
-
attn_mask (可选 Tensor) – 注意力掩码;其形状必须可以广播到注意力权重的形状 $(N,..., L, S)$。支持两种类型的掩码:一种是布尔类型,值为 True 表示该元素应参与注意力计算;另一种与 query、key 和 value 相同类型(均为浮点型)的掩码,会被添加到注意力分数中。
-
dropout_p (float) – 控制Dropout的浮点数;若其值大于0.0,则会应用Dropout。
-
is_causal (bool) – 当设置为 true 时,如果掩码是方矩阵,则注意力掩码是一个下三角矩阵。 如果掩码是非方矩阵,则由于对齐原因(参见
torch.nn.attention.bias.CausalBias
),注意力掩码具有左上角的因果偏差形式。 如果同时设置了 attn_mask 和 is_causal,则会抛出错误。 -
scale (可选 python:float, 关键字参数) – 应用softmax之前的缩放因子。如果为 None,则默认值为 $\frac{1}{\sqrt{E}}$。
-
enable_gqa (bool) – 设置为 True 时,将启用分组查询注意力(GQA),默认情况下该值为 False。
-
- 返回值
-
注意力输出的形状为$(N, ..., Hq, L, Ev)$。
- 返回类型
-
输出(Tensor)
- 图形说明:
-
-
$N: \text{批量大小} ... : \text{其他任意数量的批量维度(可选)}$
-
$S: \text{序列的长度}$
-
$L: \text{目标序列的长度}$
-
$E: \text{查询和键的嵌入维度}$
-
$Ev: \text{值的维度}$
-
$Hq: \text{查询头部的数量}$
-
$H: \text{键和值的数量}$
-
示例
>>> # Optionally use the context manager to ensure one of the fused kernels is run >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): >>> F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3 >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with sdpa_kernel(backends=[SDPBackend.MATH]): >>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True)