(测试版)使用缩放点积注意力(SDPA)实现高性能 Transformer

作者: Driss Guessous

概述

在本教程中，我们想重点介绍一个新的 torch.nn.functional 函数，它可能对实现 Transformer 架构有所帮助。该函数名为 torch.nn.functional.scaled_dot_product_attention。有关该函数的详细描述，请参阅 PyTorch 文档。该函数已被纳入 torch.nn.MultiheadAttention 和 torch.nn.TransformerEncoderLayer 中。
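作为直观感受，下面的示意代码（其中的维度、头数等超参数均为随意选取）构造了一个 torch.nn.MultiheadAttention 模块；在 need_weights=False 等快速路径条件满足时，它通常会在内部调用缩放点积注意力：

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# 示意：nn.MultiheadAttention 在快速路径条件满足时会在内部使用 SDPA
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True, device=device)
x = torch.randn(4, 16, 64, device=device)  # (batch, seq_len, embed_dim)
attn_out, _ = mha(x, x, x, need_weights=False)
print(attn_out.shape)  # torch.Size([4, 16, 64])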

在高层面上,这个 PyTorch 函数根据论文 Attention is all you need 中的定义,计算查询(query)、键(key)和值(value)之间的缩放点积注意力(SDPA)。虽然可以使用现有的 PyTorch 函数来实现这个功能,但一个融合的实现相比简单的实现能够带来显著的性能提升。
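为了直观说明“可以使用现有的 PyTorch 函数来实现这个功能”，下面给出一个未融合的参考实现草图（仅作示意，忽略了 attn_mask、dropout_p、is_causal 等参数）：

import math
import torch

def naive_sdpa(query, key, value):
    # softmax(Q @ K^T / sqrt(d)) @ V，由现有的 PyTorch 算子逐步拼出，没有任何算子融合
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value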

融合实现

对于 CUDA 张量输入，该函数将调用以下实现之一：

  • FlashAttention（对应 SDPBackend.FLASH_ATTENTION）
  • Memory-Efficient Attention（对应 SDPBackend.EFFICIENT_ATTENTION）
  • 一个用 C++ 定义的 PyTorch 数学实现（对应 SDPBackend.MATH）

本教程需要 PyTorch 2.0.0 或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691,
          -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820,
          -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607,
          -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,
           0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,
           0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,
           0.1037]]], device='cuda:0')

显式调度控制

虽然该函数会隐式地选择三种实现之一进行调度,但用户也可以通过使用上下文管理器来显式控制调度。该上下文管理器允许用户显式禁用某些实现。如果用户想要确保函数确实针对其特定输入使用了最快的实现,可以使用上下文管理器来遍历并测量性能。
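例如，下面的示意代码（假设运行在支持融合内核的 GPU 上，且使用的是 PyTorch 2.3 及以上版本中接受后端列表的 sdpa_kernel 签名）通过只允许两种融合实现参与调度来显式禁用 math 实现：

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.float16)
# 只允许 flash / memory-efficient 两种融合实现参与调度；若硬件或输入不满足条件会抛出 RuntimeError
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)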

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention,query,key,value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2327.164 microseconds
The math implementation runs in 87046.244 microseconds
The flash attention implementation runs in 2334.455 microseconds
The memory efficient implementation runs in 4344.818 microseconds

硬件依赖性

根据您运行上述代码的机器以及可用硬件的不同,您的结果可能会有所不同。

  • 如果您没有 GPU 并且在 CPU 上运行,那么使用 FP32 时上下文管理器将不会产生任何影响,三次运行的结果应该相似。
  • 根据您的显卡所支持的计算能力，flash attention 或 memory efficient 实现可能会失败（可参考下面的简单检查示例）。
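下面是一个简单的检查草图（并非官方的判定方法，仅供排查参考）：打印当前 GPU 的计算能力以及各后端的全局开关状态。

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: sm_{major}{minor}")
    # FlashAttention 的融合内核通常要求较新的架构（例如 sm_80 及以上）以及 float16/bfloat16 输入
    print("Flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())
    print("Memory-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
    print("Math SDP enabled:", torch.backends.cuda.math_sdp_enabled())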

因果自注意力机制

以下是一个受 Andrej Karpathy NanoGPT 仓库启发的多头因果自注意力块的示例实现。

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

NestedTensor 和密集张量的支持

SDPA 支持 NestedTensor 和 Dense tensor 输入。NestedTensors 处理输入为一批可变长度序列的情况，而无需将每个序列填充到批次中的最大长度。有关 NestedTensors 的更多信息，请参阅 torch.nested 和 NestedTensors 教程。
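下面是一个最小示意（其中的序列长度与特征维度为任意选取）：直接用变长序列构造一个 NestedTensor，无需填充；如需转回普通稠密张量，可以使用 torch.nested.to_padded_tensor。

import torch

seqs = [torch.randn(seq_len, 8) for seq_len in (3, 5, 2)]
nt = torch.nested.nested_tensor(seqs)                     # 直接保存变长序列，不做填充
padded = torch.nested.to_padded_tensor(nt, padding=0.0)   # 需要时再转成填充后的稠密张量
print(padded.shape)  # torch.Size([3, 5, 8])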

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model,random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model,random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:228: UserWarning:

The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)

Random NT runs in 565.292 microseconds
Random Dense runs in 947.667 microseconds

将 SDPA 与 torch.compile 结合使用

随着 PyTorch 2.0 的发布,引入了一个名为 torch.compile() 的新特性,它可以显著提升性能,超越 eager 模式。缩放点积注意力机制(scaled dot product attention)与 torch.compile() 完全兼容。为了演示这一点,让我们使用 torch.compile() 编译 CausalSelfAttention 模块,并观察由此带来的性能提升。

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model,x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model,x):.3f} microseconds")
The non compiled module runs in  416.026 microseconds
The compiled module runs in  517.141 microseconds

具体执行时间取决于机器,但我的测试结果如下:未编译模块运行时间为166.616微秒,已编译模块运行时间为166.726微秒。这并不是我们所预期的结果。让我们进一步深入分析。PyTorch自带了一个非常强大的内置分析器,您可以使用它来检查代码的性能特征。

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.537ms       101.58%      10.537ms      10.537ms             1
                         Non-Compilied Causal Attention        20.52%       2.265ms        77.21%       8.521ms       8.521ms       0.000us         0.00%      10.373ms      10.373ms             1
                                           aten::linear         1.17%     129.613us        28.65%       3.162ms      63.236us       0.000us         0.00%       7.767ms     155.333us            50
                                           aten::matmul         2.43%     268.403us        24.54%       2.708ms      54.153us       0.000us         0.00%       7.767ms     155.333us            50
                                               aten::mm        15.35%       1.694ms        19.77%       2.182ms      43.639us       7.767ms        74.87%       7.767ms     155.333us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.566ms        53.65%       5.566ms     222.628us            25
                     aten::scaled_dot_product_attention         2.01%     221.551us        18.30%       2.020ms      80.800us       0.000us         0.00%       2.607ms     104.261us            25
              aten::_scaled_dot_product_flash_attention         3.08%     340.394us        16.30%       1.798ms      71.938us       0.000us         0.00%       2.607ms     104.261us            25
                         aten::_flash_attention_forward         3.62%     399.344us        11.44%       1.262ms      50.496us       2.607ms        25.13%       2.607ms     104.261us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.607ms        25.13%       2.607ms     104.261us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.036ms
Self CUDA time total: 10.373ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.496ms       101.10%      10.496ms      10.496ms             1
                              Compiled Causal Attention         9.52%       1.061ms        78.15%       8.706ms       8.706ms       0.000us         0.00%      10.382ms      10.382ms             1
                             Torch-Compiled Region: 2/0         8.70%     969.588us        66.49%       7.407ms     296.283us       0.000us         0.00%      10.382ms     415.288us            25
                                       CompiledFunction        27.45%       3.058ms        57.79%       6.437ms     257.499us       0.000us         0.00%      10.382ms     415.288us            25
                                               aten::mm         9.95%       1.108ms        15.08%       1.680ms      33.596us       7.766ms        74.80%       7.766ms     155.315us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.563ms        53.58%       5.563ms     222.509us            25
              aten::_scaled_dot_product_flash_attention         2.29%     254.914us        15.26%       1.700ms      67.991us       0.000us         0.00%       2.616ms     104.657us            25
                         aten::_flash_attention_forward         3.69%     411.523us        11.15%       1.242ms      49.669us       2.616ms        25.20%       2.616ms     104.657us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.616ms        25.20%       2.616ms     104.657us            25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       2.203ms        21.22%       2.203ms      88.122us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.140ms
Self CUDA time total: 10.382ms

上述代码片段生成了一个报告，展示了在编译和非编译模块中消耗 GPU 执行时间最多的前 10 个 PyTorch 函数。分析表明，两个模块在 GPU 上花费的大部分时间都集中在同一组函数上。原因在于，torch.compile 非常擅长消除与 PyTorch 相关的框架开销。如果您的模型启动的是高效的大型 CUDA 内核（本例中的 CausalSelfAttention 正是如此），那么 PyTorch 的框架开销就可以被掩盖。

实际上,您的模块通常并不只包含一个单一的 CausalSelfAttention 模块。在实验 Andrej Karpathy NanoGPT 仓库时,编译该模块将每个训练步骤的时间从 6090.49ms 缩短到了 3273.17ms!这是在提交 ae3a8d5 上进行的,该提交是在 Shakespeare 数据集上训练 NanoGPT 的版本。

将 SDPA 与 attn_bias 子类结合使用

# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses.
# Designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both biases are of the same type,
# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept: assuming the attention score matrix is two dimensional,
# ``attn_score[0][0]`` is the attention score between the 0th token in the query and the 0th token in the key.
# With upper left bias, the 0th token in the query is aligned to the 0th token in the key.
# With lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is all True since the last token in q is at the same position as the last
# token in k, even if the sequence lengths of q and k are different).

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

总结

在本教程中，我们演示了 torch.nn.functional.scaled_dot_product_attention 的基本用法。我们展示了如何使用 sdpa_kernel 上下文管理器来确保在 GPU 上使用特定的实现。此外，我们构建了一个简单的 CausalSelfAttention 模块，该模块既兼容 NestedTensor，又可以通过 torch.compile 进行编译。在此过程中，我们还展示了如何使用性能分析工具来探索用户自定义模块的性能特征。
