Note
Click :ref:`here <sphx_glr_download_intermediate_scaled_dot_product_attention_tutorial.py>` to download the full example code.
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)¶
Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024
Author: Driss Guessous
Summary¶
In this tutorial, we want to highlight a new ``torch.nn.functional`` function that can be helpful for implementing transformer architectures. The function is named ``torch.nn.functional.scaled_dot_product_attention``. For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__. This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
Overview¶
At a high level, this PyTorch function computes the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper `Attention is all you need <https://arxiv.org/abs/1706.03762>`__. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
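For reference, the computation performed by the fused kernels can be sketched with basic PyTorch ops. This is a minimal sketch, assuming no attention mask and no dropout; ``naive_sdpa`` is an illustrative name, not part of the PyTorch API:

import math
import torch

def naive_sdpa(query, key, value):
    # softmax(Q @ K^T / sqrt(E)) @ V -- the non-fused reference computation
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value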
Fused implementations¶
For CUDA tensor inputs, the function will dispatch into one of the following implementations:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Memory-Efficient Attention (from the xFormers project)
- A PyTorch implementation defined in C++
Note
This tutorial requires PyTorch 2.0.0 or later.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
Explicit Dispatcher Control¶
While the function will implicitly dispatch to one of the three implementations, the user also has explicit control over which one is used via a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to make sure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through and measure performance.
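As a short sketch (assuming PyTorch 2.3 or later, where ``torch.nn.attention.sdpa_kernel`` accepts either a single backend or a list of backends), the context manager can also restrict dispatch to a subset of implementations rather than a single one:

from torch.nn.attention import SDPBackend, sdpa_kernel

# Allow only the two fused backends; if neither supports the inputs or hardware,
# the call raises a RuntimeError instead of silently falling back to the math path.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    try:
        out = F.scaled_dot_product_attention(query, key, value)
    except RuntimeError:
        print("Neither fused backend supports these inputs. See warnings for reasons.")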
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
Hardware dependence¶
Depending on what machine you ran the above cells on and what hardware is available, your results might be different.
- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient attention might have failed (see the capability-check sketch below).
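To check which constraint applied on your machine, here is a small sketch (not part of the original tutorial) for inspecting the GPU:

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU compute capability: sm_{major}{minor}")
    # FlashAttention generally requires float16/bfloat16 inputs and a sufficiently
    # recent architecture (sm80 and newer have the broadest support).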
Causal Self Attention¶
Below is an example implementation of a multi-headed causal self attention block inspired by the `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
``NestedTensor`` and Dense tensor support¶
SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable length sequences without needing to pad every sequence in the batch to the maximum length. For more information about ``NestedTensors`` see `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
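As a tiny illustration (a sketch, not part of the original tutorial), a ``NestedTensor`` can hold two sequences of different lengths without any padding:

nt = torch.nested.nested_tensor(
    [torch.randn(5, 8), torch.randn(3, 8)]  # sequence lengths 5 and 3, embed dim 8
)
print(nt.is_nested)                    # True
print([t.shape for t in nt.unbind()])  # the two ragged constituents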
import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
Using SDPA with ``torch.compile``¶
With the release of PyTorch 2.0, a new feature called ``torch.compile()`` was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with ``torch.compile()``. To demonstrate this, let's compile the ``CausalSelfAttention`` module using ``torch.compile()`` and observe the resulting performance improvements.
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
embed_dimension, device=device, dtype=dtype)
print(
f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")
compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The exact execution time is dependent on machine, however the results for mine: the non-compiled module runs in 166.616 microseconds and the compiled module runs in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis shows that the majority of GPU time is spent on the same set of functions for both modules. The reason for this is that ``torch.compile`` is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case ``CausalSelfAttention`` is, then the overhead of PyTorch can be hidden.
In reality, your module does not normally consist of a single ``CausalSelfAttention`` block. When experimenting with the `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling the module took the time per train step from ``6090.49ms`` to ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT training on the Shakespeare dataset.
Using SDPA with attn_bias subclasses¶
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
# The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
# is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#
from torch.nn.attention.bias import causal_lower_right, causal_upper_left
batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
print(type(upper_left_bias))
print(type(lower_right_bias))
assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)
# As you can see from the previous output, both biases are of the same type,
# ``torch.nn.attention.bias.CausalBias``, which is a subclass of ``torch.Tensor``.
# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)
# Upper left bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept: with upper left bias, the 0th token in the query
# is aligned to the 0th token in the key. Assuming the attention score matrix is two dimensional,
# ``attn_score[0][0]`` is the attention score between the 0th token in the query and the 0th token in the key.
# For lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is always True, since the last token in q is at the same position as the
# last token in k, even if the sequence lengths of q and k are different).
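# (A sketch that is not part of the original tutorial: the equivalent dense boolean
# masks make the two alignments concrete. The upper-left variant anchors the causal
# diagonal at position (0, 0), while the lower-right variant shifts it so that the
# last query token can attend to every key.)
ul_mask = torch.tril(torch.ones(sequence_length_q, sequence_length_kv, dtype=torch.bool))
lr_mask = torch.tril(
    torch.ones(sequence_length_q, sequence_length_kv, dtype=torch.bool),
    diagonal=sequence_length_kv - sequence_length_q,
)
print(ul_mask)
print(lr_mask)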
# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)
# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
Conclusion¶
In this tutorial, we have demonstrated the basic usage of ``torch.nn.functional.scaled_dot_product_attention``. We have shown how the ``sdpa_kernel`` context manager can be used to assert that a certain implementation is used on GPU. As well, we built a simple ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch compilable. In the process we have shown how the profiling tools can be used to explore the performance characteristics of a user defined module.