.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "intermediate/transformer_building_blocks.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_intermediate_transformer_building_blocks.py: .. meta:: :description: Learn how to optimize transformer models by replacing nn.Transformer with Nested Tensors and torch.compile() for significant performance gains in PyTorch. Accelerating PyTorch Transformers by replacing ``nn.Transformer`` with Nested Tensors and ``torch.compile()`` ============================================================================================================= **Author:** `Mikayla Gawarecki `_ .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn :class-card: card-prerequisites * Learn about the low-level building blocks PyTorch provides to build custom transformer layers ( nested tensors, ``scaled_dot_product_attention``, ``torch.compile()``, and ``FlexAttention``) * Discover how the above improve memory usage and performance using MultiHeadAttention as an example * Explore advanced customizations using the aforementioned building blocks .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites :class-card: card-prerequisites * PyTorch v.2.6.0 or later Over the past few years, the PyTorch team has developed various lower level features that, when composed, can create a variety of transformer variants. These include: * Nested Tensors with the ``torch.jagged`` layout (AKA NJTs) * ``scaled_dot_product_attention`` * ``torch.compile()`` * ``FlexAttention`` This tutorial will give a brief overview of the above technologies and demonstrate how they can be composed to yield flexible and performant transformer layers with improved user experience. One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers. In particular, it includes ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``, ``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family of layers was initially implemented following the `Attention is All You Need `_ paper. The components discussed in this tutorial provide improved user experience, flexibility and performance over the existing ``nn`` layers. Is this tutorial for me? ======================== If you are wondering about what building blocks the ``torch`` library provides for writing your own transformer layers and best practices, you are in the right place. Please keep reading! If you are looking for an out-of-the-box implementation of a popular transformer architecture, note that there are many open-source libraries that provide them, including: * `HuggingFace transformers `_ * `xformers `_ * `torchtune `_ If you are only interested in performant attention score modifications, please check out the `FlexAttention blog `_ that contains a `gym of masks `_. .. GENERATED FROM PYTHON SOURCE LINES 70-128 Introducing the Building Blocks =============================== First, we will briefly introduce the four technologies mentioned in the introduction * `torch.nested `_ Nested tensors generalize the shape of regular dense tensors, allowing for representation of ragged-sized data with the same tensor UX. 
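For instance, here is a minimal sketch of packing variable-length sequences into a jagged-layout
nested tensor (the sequence lengths and embedding size below are illustrative, not values used
elsewhere in this tutorial):

.. code-block:: python

    import torch

    # Three sequences of different lengths, each with embedding dimension 8
    sequences = [torch.randn(seq_len, 8) for seq_len in (3, 5, 2)]

    # Pack them into a single nested tensor with the ragged (jagged) layout;
    # the batch dimension is 3 and the sequence dimension is ragged.
    nt = torch.nested.nested_tensor(sequences, layout=torch.jagged)

    # A dense, zero-padded copy of shape (3, 5, 8) can be recovered when needed.
    padded = nt.to_padded_tensor(0.0)

Operations such as ``nn.Linear`` and ``scaled_dot_product_attention`` accept such tensors directly,
as the rest of this tutorial demonstrates.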
In the context of transformers, we can think of nested tensors as a tool for representing variable
sequence lengths. They eliminate the need for the bug-prone practices of explicit padding and
masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).

* `scaled_dot_product_attention `_

  ``scaled_dot_product_attention`` is a primitive for
  :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
  implementations of the operator or a fallback implementation. It works out of the box in eager
  mode (i.e. the default mode of using PyTorch where operations are executed on the fly as they are
  encountered) and also integrates seamlessly with ``torch.compile()``. As of 2.6, it also offers
  grouped query attention natively.

* `torch.compile() `_

  ``torch.compile()`` is a compiler introduced in version 2.0 that is able to capture a graph of
  PyTorch code and perform various optimizations on it, such as fusing together sequences of ops.
  Nested tensors with the ``torch.jagged`` layout and ``scaled_dot_product_attention`` work
  seamlessly with compile. In the context of transformers, the added value of using compile with
  nested tensors and SDPA is that compile can remove the framework overhead one sees in eager mode
  and fuse sequences of ops in transformers together, such as projection and activation.

* `FlexAttention `_

  ``FlexAttention`` is a primitive that allows users to modify attention scores prior to the
  softmax operation. It generalizes the additive ``B`` term above for
  ``scaled_dot_product_attention``, allowing for arbitrary score modifications. It requires compile
  to achieve good performance.

The above building blocks are "All You Need" (as of October 2024)
==================================================================

The main premise in this section is that most transformer variations are GPT-style, consisting of
layers like Embedding, Positional Encoding, Attention Blocks and Feed Forward networks. If we were
to try to classify the differences in this space, we might land on something like:

1. Layer type (activation functions such as ``SwiGLU``, normalization functions such as
   ``RMSNorm``, and positional encodings such as Sinusoidal or Rotary)
2. Layer ordering, such as where to apply norms and positional encoding.
3. Modifications to the attention score, such as ``ALiBi``, Relative Positional Bias, and so on.

In a pre-compiler environment, you might write a custom transformer, notice that it functions
correctly but runs slowly, and then hand-write a custom fused kernel for that specific series of
operations. In a compiler environment, you can simply write the custom transformer, compile it,
and benefit from the improved performance.

.. GENERATED FROM PYTHON SOURCE LINES 131-162

MultiheadAttention
------------------

Remember that MultiheadAttention takes in a query, key, and value, and consists of an input
projection, a ``scaled_dot_product_attention`` operator and an output projection. The main
takeaway we want to demonstrate here is the improvement yielded when we replace padded/masked
inputs with nested tensors. The improvements are threefold:

* **User Experience**
  Remember that ``nn.MultiheadAttention`` requires ``query``, ``key``, and ``value`` to be dense
  ``torch.Tensors``. It also provides a ``key_padding_mask`` that is used to mask out padding
  tokens in the ``key`` that arise due to different sequence lengths within a batch.
Since there is no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice the outputs appropriately to account for query sequence lengths. ``NestedTensor`` cleanly removes the need for this sort of error-prone padding masks. * **Memory** Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]`` padding mask (where ``B`` is batch size, ``S`` is max sequence length in the batch and ``D`` is embedding size), nested tensors allow you to cleanly represent the batch of varying sequence lengths. As a result, the input and intermediate activations will use less memory. * **Performance** Since padding is not materialized and unnecessary computation on padding is skipped, performance and memory usage improve. We'll demonstrate the above by building upon the ``MultiheadAttention`` layer in the `Nested Tensor tutorial `_ and comparing it to the ``nn.MultiheadAttention`` layer. .. GENERATED FROM PYTHON SOURCE LINES 162-287 .. code-block:: default import torch import torch.nn as nn import torch.nn.functional as F class MultiHeadAttention(nn.Module): """ Computes multi-head attention. Supports nested or padded tensors. Args: E_q (int): Size of embedding dim for query E_k (int): Size of embedding dim for key E_v (int): Size of embedding dim for value E_total (int): Total embedding dim of combined heads post input projection. Each head has dim E_total // nheads nheads (int): Number of heads dropout (float, optional): Dropout probability. Default: 0.0 bias (bool, optional): Whether to add bias to input projection. Default: True """ def __init__( self, E_q: int, E_k: int, E_v: int, E_total: int, nheads: int, dropout: float = 0.0, bias=True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.nheads = nheads self.dropout = dropout self._qkv_same_embed_dim = E_q == E_k and E_q == E_v if self._qkv_same_embed_dim: self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs) else: self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs) self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs) E_out = E_q self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs) assert E_total % nheads == 0, "Embedding dim is not divisible by nheads" self.E_head = E_total // nheads self.bias = bias def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, is_causal=False, ) -> torch.Tensor: """ Forward pass; runs the following process: 1. Apply input projection 2. Split heads and prepare for SDPA 3. Run SDPA 4. Apply output projection Args: query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``) key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``) value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``) attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None is_causal (bool, optional): Whether to apply causal mask. Default: False Returns: attn_output (torch.Tensor): output of shape (N, L_t, E_q) """ # Step 1. 
Apply input projection if self._qkv_same_embed_dim: if query is key and key is value: result = self.packed_proj(query) query, key, value = torch.chunk(result, 3, dim=-1) else: q_weight, k_weight, v_weight = torch.chunk( self.packed_proj.weight, 3, dim=0 ) if self.bias: q_bias, k_bias, v_bias = torch.chunk( self.packed_proj.bias, 3, dim=0 ) else: q_bias, k_bias, v_bias = None, None, None query, key, value = ( F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias), ) else: query = self.q_proj(query) key = self.k_proj(key) value = self.v_proj(value) # Step 2. Split heads and prepare for SDPA # reshape query, key, value to separate by head # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2) # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2) # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2) # Step 3. Run SDPA # (N, nheads, L_t, E_head) attn_output = F.scaled_dot_product_attention( query, key, value, dropout_p=self.dropout, is_causal=is_causal ) # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total) attn_output = attn_output.transpose(1, 2).flatten(-2) # Step 4. Apply output projection # (N, L_t, E_total) -> (N, L_t, E_out) attn_output = self.out_proj(attn_output) return attn_output .. GENERATED FROM PYTHON SOURCE LINES 288-293 Utilities ~~~~~~~~~ In this section, we include a utility to generate semi-realistic data using ``Zipf`` distribution for sentence lengths. This is used to generate the nested query, key, and value tensors. We also include a benchmark utility. .. GENERATED FROM PYTHON SOURCE LINES 293-367 .. code-block:: default import numpy as np def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor: # generate fake corpus by unigram Zipf distribution # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858 sentence_lengths = np.empty(batch_size, dtype=int) for ibatch in range(batch_size): sentence_lengths[ibatch] = 1 word = np.random.zipf(alpha) while word != 3 and word != 386 and word != 858: sentence_lengths[ibatch] += 1 word = np.random.zipf(alpha) return torch.tensor(sentence_lengths) # Generate a batch of semi-realistic data using Zipf distribution for sentence lengths # in the form of nested tensors with the jagged layout. def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False): # generate semi-realistic data using Zipf distribution for sentence lengths sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N) # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged # dimension and works with torch.compile. The batch items each have shape (B, S*, D) # where B = batch size, S* = ragged sequence length, and D = embedding dimension. 
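    # Concretely, each list element passed to torch.nested.nested_tensor below has shape
    # (S_i, D) for its own sequence length S_i; stacking the list with layout=torch.jagged
    # produces a single nested tensor with logical shape (B, S*, D). A dense, zero-padded
    # copy can be recovered later with to_padded_tensor(0.0).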
if query_seq_len_1: query = torch.nested.nested_tensor( [torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths], layout=torch.jagged, ) else: query = torch.nested.nested_tensor( [ torch.randn(l.item(), E_q, dtype=dtype, device=device) for l in sentence_lengths ], layout=torch.jagged, ) key = torch.nested.nested_tensor( [ torch.randn(s.item(), E_k, dtype=dtype, device=device) for s in sentence_lengths ], layout=torch.jagged, ) value = torch.nested.nested_tensor( [ torch.randn(s.item(), E_v, dtype=dtype, device=device) for s in sentence_lengths ], layout=torch.jagged, ) return query, key, value, sentence_lengths import math import timeit def benchmark(func, *args, **kwargs): torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() begin = timeit.default_timer() output = func(*args, **kwargs) torch.cuda.synchronize() end = timeit.default_timer() return output, (end - begin), torch.cuda.max_memory_allocated() .. GENERATED FROM PYTHON SOURCE LINES 368-371 We will now demonstrate the performance improvements of using nested tensors in the ``MultiheadAttention`` layer + compile for self attention. We compare this against the traditional ``nn.MultiheadAttention`` + compile with padding and masking. .. GENERATED FROM PYTHON SOURCE LINES 371-463 .. code-block:: default N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512 E_out = E_q d_model = E_q nheads = 8 dropout = 0.0 bias = True device = "cuda" torch.manual_seed(6) query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device) S = sentence_lengths.max().item() print( f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}" ) padded_query, padded_key, padded_value = ( t.to_padded_tensor(0.0) for t in (query, key, value) ) torch.manual_seed(6) mha_layer = MultiHeadAttention( E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda" ) torch.manual_seed(6) vanilla_mha_layer = nn.MultiheadAttention( E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda" ) # ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :( mha_layer.out_proj.weight = nn.Parameter( vanilla_mha_layer.out_proj.weight.clone().detach() ) mha_layer.packed_proj.weight = nn.Parameter( vanilla_mha_layer.in_proj_weight.clone().detach() ) mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach()) mha_layer.packed_proj.bias = nn.Parameter( vanilla_mha_layer.in_proj_bias.clone().detach() ) new_mha_layer = torch.compile(mha_layer) # warmup compile nested_result_warmup = new_mha_layer(query, query, query, is_causal=True) # benchmark nested_result, nested_time, nested_peak_memory = benchmark( new_mha_layer, query, query, query, is_causal=True ) padded_nested_result = nested_result.to_padded_tensor(0.0) # For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask`` # Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal`` src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0] attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf")) for i, s in enumerate(sentence_lengths): attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s) attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S) vanilla_mha_layer = torch.compile(vanilla_mha_layer) # warmup compile warmup_vanilla_result = vanilla_mha_layer( padded_query, padded_query, padded_query, 
attn_mask=attn_mask, key_padding_mask=src_key_padding_mask, need_weights=False, is_causal=True, ) # benchmark (padded_result, _), padded_time, padded_peak_memory = benchmark( vanilla_mha_layer, padded_query, padded_query, padded_query, key_padding_mask=src_key_padding_mask, need_weights=False, attn_mask=attn_mask, is_causal=True, ) print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB") print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB") print( "Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item(), ) print(f"Nested speedup: {(padded_time/nested_time):.2f}") print( f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB" ) .. GENERATED FROM PYTHON SOURCE LINES 464-475 For reference, here are some sample outputs on A100: .. code:: padded_time=0.03454, padded_peak_memory=4.14 GB nested_time=0.00612, nested_peak_memory=0.76 GB Max difference between vanilla and nested result 0.0 Nested speedup: 5.65 Nested peak memory reduction 3.39 GB We can also see the same for backward pass .. GENERATED FROM PYTHON SOURCE LINES 475-523 .. code-block:: default for i, entry_length in enumerate(sentence_lengths): # padding-specific step: remove output projection bias from padded entries for fair comparison padded_result[i, entry_length:, :] = 0.0 _, padded_bw_time, padded_bw_peak_mem = benchmark( lambda: padded_result.sum().backward() ) _, nested_bw_time, nested_bw_peak_mem = benchmark( lambda: padded_nested_result.sum().backward() ) print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB") print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB") print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}") print( f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB" ) print( "Difference in out_proj.weight.grad", (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad) .abs() .max() .item(), ) print( "Difference in packed_proj.weight.grad", (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad) .abs() .max() .item(), ) print( "Difference in out_proj.bias.grad", (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad) .abs() .max() .item(), ) print( "Difference in packed_proj.bias.grad", (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad) .abs() .max() .item(), ) .. GENERATED FROM PYTHON SOURCE LINES 524-537 Sample outputs on A100: .. code:: padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB Nested backward speedup: 144.13 Nested backward peak memory reduction 1.86 GB Difference in out_proj.weight.grad 0.000244140625 Difference in packed_proj.weight.grad 0.001556396484375 Difference in out_proj.bias.grad 0.0 Difference in packed_proj.bias.grad 0.001953125 .. GENERATED FROM PYTHON SOURCE LINES 539-550 GPT-style layer --------------- A basic GPT-style transformer layer consists of a causal self-attention layer followed by a feed-forward network (FFN) with skip connections. Implementing this is fairly straightforward using the ``MultiheadAttention`` layer above and gives equivalent results to an ``nn.TransformerEncoderLayer`` with ``is_causal=True``. We demonstrate examples of implementing the rest of the ``nn`` layers `here `_ but omit that from this tutorial for brevity. .. 
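For a sense of what such a layer can look like, below is a minimal sketch built on the
``MultiHeadAttention`` class defined above. The ``GPTBlock`` name, the pre-norm placement, and the
ReLU feed-forward network are illustrative choices, not the tutorial's reference implementation:

.. code-block:: python

    class GPTBlock(nn.Module):
        def __init__(
            self, d_model, nheads, dim_feedforward, dropout=0.0, bias=True, device=None, dtype=None
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.norm1 = nn.LayerNorm(d_model, **factory_kwargs)
            self.attn = MultiHeadAttention(
                d_model, d_model, d_model, d_model, nheads,
                dropout=dropout, bias=bias, **factory_kwargs
            )
            self.norm2 = nn.LayerNorm(d_model, **factory_kwargs)
            self.ff = nn.Sequential(
                nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs),
                nn.ReLU(),
                nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs),
            )

        def forward(self, x):
            # causal self-attention sub-block with a skip connection
            h = self.norm1(x)
            x = x + self.attn(h, h, h, is_causal=True)
            # feed-forward sub-block with a skip connection
            x = x + self.ff(self.norm2(x))
            return x

As with the layers above, such a block can be compiled and fed nested tensors directly, for
example ``torch.compile(GPTBlock(E_q, nheads, 4 * E_q, device="cuda"))(query)``.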
GENERATED FROM PYTHON SOURCE LINES 553-569 Going one step further ---------------------- So far, we have demonstrated how to implement a performant ``MultiheadAttention`` layer that follows the traditional ``nn.MultiheadAttention``. Going back to our classification of modifications to the transformer architecture, remember that we classified the modifications into layer type, layer ordering, and modifications to the attention score. We trust that changing layer type and layer ordering (such as swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward. In this section, we will discuss various functionalities using the aforementioned building blocks, including the following: * Cross Attention * Fully masked rows no longer cause NaNs * Modifying attention score: ALiBi with FlexAttention and NJT * Packed Projection .. GENERATED FROM PYTHON SOURCE LINES 571-581 Cross Attention --------------- Cross attention is a form of attention where the query and key/value tensors are from different sequences. One example of this is in ``nn.TransformerDecoderLayer`` where the query comes from the decoder and the key/value come from the encoder. The above MultiheadAttention layer nicely generalizes to this case with nested tensors for both query and key/value. .. GENERATED FROM PYTHON SOURCE LINES 581-593 .. code-block:: default query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device) _, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device) print( f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}" ) print( f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}" ) out = new_mha_layer(query, key, value, is_causal=False) .. GENERATED FROM PYTHON SOURCE LINES 594-595 As above, we can compare this against the vanilla compiled ``nn.MultiheadAttention``. .. GENERATED FROM PYTHON SOURCE LINES 595-642 .. code-block:: default torch.manual_seed(6) query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device) _, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device) padded_query, padded_key, padded_value = ( t.to_padded_tensor(0.0) for t in (query, key, value) ) key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0] # warmup compile warmup_nested_result = new_mha_layer(query, key, value, is_causal=False) warmup_vanilla_result = vanilla_mha_layer( padded_query, padded_key, padded_value, key_padding_mask=key_padding_mask, need_weights=False, is_causal=False, ) nested_result, nested_time, nested_peak_memory = benchmark( new_mha_layer, query, key, value, is_causal=False ) (padded_result, _), padded_time, padded_peak_memory = benchmark( vanilla_mha_layer, padded_query, padded_key, padded_value, key_padding_mask=key_padding_mask, need_weights=False, is_causal=False, ) padded_nested_result = nested_result.to_padded_tensor(0.0) for i, entry_length in enumerate(q_len): # padding-specific step: remove output projection bias from padded entries for fair comparison padded_result[i, entry_length:, :] = 0.0 print( "Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item(), ) print(f"Nested speedup: {(padded_time/nested_time):.2f}") print( f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB" ) .. GENERATED FROM PYTHON SOURCE LINES 643-651 Sample outputs on A100: .. code:: Max difference between vanilla and nested result 0.0 Nested speedup: 4.01 Nested peak memory reduction 1.40 GB .. 
GENERATED FROM PYTHON SOURCE LINES 653-669 Fully masked rows no longer cause NaNs -------------------------------------- There has been a long standing issue with ``nn.MultiheadAttention`` and ``scaled_dot_product_attention`` where if a row was fully masked out, the output of the attention layer would be NaN. See `issue `_. This is because the softmax over an empty set is undefined. Thanks to `this PR `_ this is no longer the case. Instead, the output corresponding to fully masked rows in ``scaled_dot_product_attention`` will be 0. For cases where ``nn.MHA`` does not employ the "fast-path", this will also apply. Using a custom MHA layer with NJTs is strongly recommended over the existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness appropriately makes it possible to properly express empty sequences. .. GENERATED FROM PYTHON SOURCE LINES 672-680 FlexAttention + NJT --------------------------------------------------------------------- NJT also composes with the ``FlexAttention`` module. This is a generalization of the ``MultiheadAttention`` layer that allows for arbitrary modifications to the attention score. The example below takes the ``alibi_mod`` that implements `ALiBi `_ from `attention gym `_ and uses it with nested input tensors. .. GENERATED FROM PYTHON SOURCE LINES 680-708 .. code-block:: default from torch.nn.attention.flex_attention import flex_attention def generate_alibi_bias(H: int): """Returns an alibi bias score_mod given the number of heads H Args: H: number of heads Returns: alibi_bias: alibi bias score_mod """ def alibi_mod(score, b, h, q_idx, kv_idx): scale = torch.exp2(-((h + 1) * 8.0 / H)) bias = (q_idx - kv_idx) * scale return score + bias return alibi_mod query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device) n_heads, D = 8, E_q // 8 alibi_score_mod = generate_alibi_bias(n_heads) query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod) .. GENERATED FROM PYTHON SOURCE LINES 709-716 In addition, one can also use the ``block_mask`` utility of ``FlexAttention`` with NJTs via the ``create_nested_block_mask`` function. This is useful for taking advantage of the sparsity of the mask to speed up the attention computation. In particular, the function creates a sparse block mask for a "stacked sequence" of all the variable length sequences in the NJT combined into one, while properly masking out inter-sequence attention. In the following example, we show how to create a causal block mask using this utility. .. GENERATED FROM PYTHON SOURCE LINES 716-731 .. code-block:: default from torch.nn.attention.flex_attention import create_nested_block_mask def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device) block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() out_flex = flex_attention(query, key, value, block_mask=block_mask) .. 
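``flex_attention`` also accepts a ``score_mod`` and a ``block_mask`` together in a single call, and
wrapping it in ``torch.compile()`` is what makes it performant. The snippet below is a brief sketch
of this, reusing the ``alibi_score_mod`` and ``block_mask`` from above; combining the two in this
way is our illustration rather than a step from the original example:

.. code-block:: python

    # FlexAttention requires compile to achieve good performance
    compiled_flex = torch.compile(flex_attention)

    # Add the ALiBi bias to each score while skipping blocks ruled out by the causal mask
    out_combined = compiled_flex(
        query, key, value, score_mod=alibi_score_mod, block_mask=block_mask
    )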
GENERATED FROM PYTHON SOURCE LINES 732-752

Packed Projection
-----------------

Packed projection is a technique that makes use of the fact that, when the inputs to the
projections (matrix multiplications) are the same (as in self-attention), we can pack the
projection weights and biases into single tensors. It is especially useful when the individual
projections are memory bound rather than compute bound.

There are two examples that we will demonstrate here:

* Input projection for MultiheadAttention
* SwiGLU activation in the feed-forward network of a Transformer Layer

Input projection for MultiheadAttention
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When doing self-attention, the ``query``, ``key``, and ``value`` are the same tensor. Each of
these tensors is projected with a ``Linear(E_q, E_total)`` layer. Instead, we can pack this into
one layer, which is what we do in the MultiheadAttention layer above.

Let us compare the performance of the packed projection against the usual method:

.. GENERATED FROM PYTHON SOURCE LINES 752-798

.. code-block:: default


    class InputProjection(nn.Module):
        def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

        def forward(self, x):
            return self.q_proj(x), self.k_proj(x), self.v_proj(x)


    class PackedInputProjection(nn.Module):
        def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

        def forward(self, query):
            return torch.chunk(self.packed_proj(query), 3, dim=-1)


    B, D, dtype = 256, 8192, torch.bfloat16

    torch.set_float32_matmul_precision("high")
    in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
    packed_in_proj = torch.compile(
        PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
    )

    q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)

    # warmup
    in_proj(q)
    packed_in_proj(q)

    # benchmark
    (q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
    (q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)

    # On my A100 prints 1.05x speedup
    print(
        f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
    )

.. GENERATED FROM PYTHON SOURCE LINES 799-803

SwiGLU feed forward network of Transformer Layer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular
in the feed-forward network of the transformer layer (for example, in Llama). A feed-forward
network with SwiGLU activation is defined as:

.. GENERATED FROM PYTHON SOURCE LINES 803-831

.. code-block:: default

    class SwiGLUFFN(nn.Module):
        def __init__(
            self,
            dim,
            hidden_dim,
            multiple_of,
            ffn_dim_multiplier=None,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            hidden_dim = int(2 * hidden_dim / 3)
            # custom dim factor multiplier
            if ffn_dim_multiplier is not None:
                hidden_dim = int(ffn_dim_multiplier * hidden_dim)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
            self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

.. GENERATED FROM PYTHON SOURCE LINES 832-833

An alternative way of implementing this that uses packed projection is:

.. GENERATED FROM PYTHON SOURCE LINES 833-861

.. code-block:: default


    class PackedSwiGLUFFN(nn.Module):
        def __init__(
            self,
            dim,
            hidden_dim,
            multiple_of,
            ffn_dim_multiplier=None,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            hidden_dim = int(2 * hidden_dim / 3)
            # custom dim factor multiplier
            if ffn_dim_multiplier is not None:
                hidden_dim = int(ffn_dim_multiplier * hidden_dim)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
            self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

        def forward(self, x):
            x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
            return self.w2(F.silu(x1) * x3)

.. GENERATED FROM PYTHON SOURCE LINES 862-865

We can compare the performance of the two implementations as follows. Depending on your hardware,
you might see different results; on an A100 I see a 1.12x speedup for D=128.

.. GENERATED FROM PYTHON SOURCE LINES 865-886

.. code-block:: default


    D = 128

    swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
    packed_swigluffn = torch.compile(
        PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
    )

    q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

    # warmup
    swigluffn(q)
    packed_swigluffn(q)

    # benchmark
    _, time, _ = benchmark(swigluffn, q)
    _, time_packed, _ = benchmark(packed_swigluffn, q)

    # On my A100 prints 1.08x speedup
    print(
        f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
    )

.. GENERATED FROM PYTHON SOURCE LINES 887-899

Extended examples
-----------------

We intend to update this tutorial to demonstrate more examples of how to use the various
performant building blocks, such as KV-Caching, Grouped Query Attention, and so on. Further, there
are several good examples of using these building blocks to implement various transformer
architectures. Some examples include:

* `gpt-fast `_
* `segment-anything-fast `_
* `lucidrains implementation of NaViT with nested tensors `_
* `torchtune's implementation of VisionTransformer `_

.. GENERATED FROM PYTHON SOURCE LINES 901-908

Conclusion
----------

In this tutorial, we have introduced the low-level building blocks PyTorch provides for writing
transformer layers and demonstrated examples of how to compose them. It is our hope that this
tutorial has shown the reader how easily flexible and performant transformer layers can be
implemented by users of PyTorch.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_intermediate_transformer_building_blocks.py:

..
only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: transformer_building_blocks.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: transformer_building_blocks.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_