.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "recipes/torch_compile_user_defined_triton_kernel_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_recipes_torch_compile_user_defined_triton_kernel_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_recipes_torch_compile_user_defined_triton_kernel_tutorial.py:


Using User-Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen `_

.. GENERATED FROM PYTHON SOURCE LINES 10-32

User-defined Triton kernels can be used to optimize specific parts of your
model's computation. These kernels are written in Triton's language, which is
designed to make it easier to achieve peak hardware performance. By using
user-defined Triton kernels with ``torch.compile``, you can integrate these
optimized computations into your PyTorch model, potentially achieving
significant performance improvements.

This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``.

Prerequisites
-------------------

Before starting this recipe, make sure that you have the following:

* Basic understanding of ``torch.compile`` and Triton. See:

  * `torch.compiler API documentation `__
  * `Introduction to torch.compile `__
  * `Triton language documentation `__

* PyTorch 2.3 or later
* A GPU that supports Triton

.. GENERATED FROM PYTHON SOURCE LINES 32-36

.. code-block:: default

    import torch
    from torch.utils._triton import has_triton

.. GENERATED FROM PYTHON SOURCE LINES 37-44

Basic Usage
--------------------

In this example, we will use a simple vector addition kernel from the Triton
documentation with ``torch.compile``. For reference, see
`Triton documentation `__.

.. GENERATED FROM PYTHON SOURCE LINES 44-81

.. code-block:: default

    if not has_triton():
        print("Skipping because triton is not supported on this device.")
    else:
        import triton
        from triton import language as tl

        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True)
        def add_fn(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        x = torch.randn(4, device="cuda")
        y = torch.randn(4, device="cuda")
        out = add_fn(x, y)
        print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

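
If Triton is available and the ``else`` branch above executed, a quick sanity
check is to compare the compiled kernel's output against eager addition (a
minimal sketch reusing the tensors defined above):

.. code-block:: python

    # ``out`` was produced by the compiled ``add_fn``; eager ``x + y`` is the reference.
    assert torch.allclose(out, x + y)
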

.. GENERATED FROM PYTHON SOURCE LINES 82-96

Advanced Usage
-------------------------------------------------------------------

Triton's autotune feature is a powerful tool that automatically optimizes the
configuration parameters of your Triton kernels. It explores a range of possible
configurations and selects the one that delivers the best performance for your
specific use case.

When used with ``torch.compile``, ``triton.autotune`` can help ensure that your
PyTorch model is running as efficiently as possible. Here is an example of using
``torch.compile`` and ``triton.autotune``.

.. note::

   ``torch.compile`` only supports configs and key arguments to ``triton.autotune``.

.. GENERATED FROM PYTHON SOURCE LINES 96-142

.. code-block:: default

    if not has_triton():
        print("Skipping because triton is not supported on this device.")
    else:
        import triton
        from triton import language as tl

        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
                triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
                triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
            ],
            key=[],
        )
        @triton.jit
        def add_kernel_autotuned(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        @torch.compile(fullgraph=True)
        def add_fn(x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuned[grid](x, y, output, n_elements)
            return output

        x = torch.randn(4, device="cuda")
        y = torch.randn(4, device="cuda")
        out = add_fn(x, y)
        print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

.. GENERATED FROM PYTHON SOURCE LINES 143-198

Composability
-------------------------------------------------------------------

User-defined Triton kernels do not automatically support all PyTorch
subsystems. This can be seen in the following use cases:

- Adding a CPU fallback
- Adding a ``FlopCounter`` formula
- Composing with Tensor Subclasses

To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.

``triton_op`` is a structured way of defining a custom operator backed by one or
more Triton kernels: like regular custom operators (``torch.library.custom_op``),
you can specify its interactions with PyTorch subsystems via ``torch.library``.
However, unlike ``torch.library.custom_op``, which creates opaque callables with
respect to ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply
optimizations (a minimal ``custom_op`` sketch follows the table below for contrast).

Here’s a chart of which API to use when integrating Triton kernels with PyTorch.

.. list-table::
   :header-rows: 1

   * -
     - Triton kernel (no explicit ``torch.library`` wrapper)
     - ``torch.library.triton_op``
     - ``torch.library.custom_op``
   * - Supports inference
     - Yes
     - Yes
     - Yes
   * - Supports training
     - In the majority of cases
     - Yes
     - Yes
   * - Supports ``torch.compile``
     - Yes
     - Yes
     - Yes
   * - Supports ``torch.compile(fullgraph=True)``
     - In the majority of cases
     - In the majority of cases
     - In all cases
   * - Does ``torch.compile`` trace into the implementation?
     - Yes
     - Yes
     - No
   * - Supports AOTInductor
     - Yes
     - Yes
     - No
   * - Supports PyTorch Subsystems like FlopCounterMode, CPU Fallback, Tensor Subclasses
     - No
     - Yes
     - Yes

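
For contrast, here is what the opaque-callable route looks like. This is a
minimal, illustrative sketch (the ``mylib::myadd`` operator below is hypothetical
and not used elsewhere in this recipe); ``torch.compile`` does not trace into its
body, whereas it does trace into a ``triton_op``:

.. code-block:: python

    import torch

    # A plain custom op: opaque to ``torch.compile``, so its body is never traced.
    @torch.library.custom_op("mylib::myadd", mutates_args=())
    def myadd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    # A fake (meta) implementation so the op works under compilation and fake tensors.
    @myadd.register_fake
    def _(x, y):
        return torch.empty_like(x)
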

.. GENERATED FROM PYTHON SOURCE LINES 200-205

Wrapping Triton kernels with ``triton_op``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Use ``torch.library.triton_op`` to wrap a function that may invoke one or more
Triton kernels. Use ``torch.library.wrap_triton`` to wrap the calls to the
Triton kernel.

.. GENERATED FROM PYTHON SOURCE LINES 205-230

.. code-block:: default

    from torch.library import triton_op, wrap_triton

    @triton_op("mylib::mysin", mutates_args={})
    def mysin(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n_elements = x.numel()
        wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
        return out

    @triton.jit
    def sin_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.sin(x)
        tl.store(out_ptr + offsets, output, mask=mask)

.. GENERATED FROM PYTHON SOURCE LINES 231-232

You can invoke the ``triton_op`` in one of the following two ways.

.. GENERATED FROM PYTHON SOURCE LINES 232-240

.. code-block:: default

    x = torch.randn(3, device="cuda")
    y = mysin(x)
    z = torch.ops.mylib.mysin.default(x)

    assert torch.allclose(y, x.sin())
    assert torch.allclose(z, x.sin())

.. GENERATED FROM PYTHON SOURCE LINES 241-242

The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

.. GENERATED FROM PYTHON SOURCE LINES 242-246

.. code-block:: default

    y = torch.compile(mysin)(x)
    assert torch.allclose(y, x.sin())

.. GENERATED FROM PYTHON SOURCE LINES 247-253

Adding training support
^^^^^^^^^^^^^^^^^^^^^^^

Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
Prefer this to using ``torch.autograd.Function`` (which has various composability
footguns with ``torch.compile``).

.. GENERATED FROM PYTHON SOURCE LINES 253-264

.. code-block:: default

    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad * x.cos()

    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    mysin.register_autograd(backward, setup_context=setup_context)

.. GENERATED FROM PYTHON SOURCE LINES 265-267

Note that the backward must be a composition of PyTorch-understood operators.
If you want the backward to call Triton kernels, then those must be wrapped in
``triton_op`` as well:

.. GENERATED FROM PYTHON SOURCE LINES 267-300

.. code-block:: default

    @triton.jit
    def cos_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.cos(x)
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton_op("mylib::mycos", mutates_args={})
    def mycos(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n_elements = x.numel()
        wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
        return out

    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad * mycos(x)

    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    mysin.register_autograd(backward, setup_context=setup_context)

.. GENERATED FROM PYTHON SOURCE LINES 301-304

Adding a CPU Fallback
^^^^^^^^^^^^^^^^^^^^^

Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any
other device) fallback for the ``triton_op``:

.. GENERATED FROM PYTHON SOURCE LINES 304-313

.. code-block:: default

    @mysin.register_kernel("cpu")
    def _(x):
        return torch.sin(x)

    x = torch.randn(3)
    y = mysin(x)
    assert torch.allclose(y, x.sin())

.. GENERATED FROM PYTHON SOURCE LINES 314-315

The fallback must be composed of PyTorch operators.

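
At this point ``mysin`` has a Triton-backed CUDA kernel, an autograd formula,
and a CPU fallback. As a quick end-to-end check of the registrations above,
gradients can be verified on GPU (a minimal sketch; it assumes a CUDA device is
available, since the registered backward calls the Triton-backed ``mycos``):

.. code-block:: python

    x = torch.randn(3, device="cuda", requires_grad=True)
    y = mysin(x)
    y.sum().backward()
    # d/dx sin(x) == cos(x), computed by the backward registered above.
    assert torch.allclose(x.grad, x.detach().cos())
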

.. GENERATED FROM PYTHON SOURCE LINES 317-322

Adding a FlopCounter Formula
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To specify how many flops the Triton kernel reports under PyTorch's flop counter,
use ``register_flop_formula``.

.. GENERATED FROM PYTHON SOURCE LINES 322-334

.. code-block:: default

    from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

    @register_flop_formula(torch.ops.mylib.mysin)
    def _(x_shape):
        numel = 1
        for s in x_shape:
            numel *= s
        return numel

    x = torch.randn(3, device="cuda")

.. GENERATED FROM PYTHON SOURCE LINES 335-341

``FlopCounterMode`` requires `tabulate `__.
Before running the code below, make sure you have ``tabulate`` installed, or
install it by running ``pip install tabulate``.

>>> with FlopCounterMode() as flop_counter:
>>>     y = mysin(x)

.. GENERATED FROM PYTHON SOURCE LINES 343-373

Limitations
--------------------------------------------------------------------

As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
You can use these features together to build complex, high-performance models.

PyTorch 2.6 added ``torch.library.triton_op``, which extends user-defined Triton
kernel support to tensor subclasses and other advanced features.

However, there are certain limitations to be aware of:

* **Triton Features:** While ``triton.heuristics`` can be used either standalone or
  before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
  implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
  together, ``triton.heuristics`` must be used first.

Conclusion
-----------
In this recipe, we explored how to utilize user-defined Triton kernels with
``torch.compile``. We delved into the basic usage of a simple vector addition
kernel and advanced usage involving Triton's autotune feature. We also discussed
the composability of user-defined Triton kernels with other PyTorch features and
highlighted some current limitations.

See Also
---------

* `Compiling the Optimizers `__
* `Implementing High-Performance Transformers with Scaled Dot Product Attention `__


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_recipes_torch_compile_user_defined_triton_kernel_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_compile_user_defined_triton_kernel_tutorial.py <torch_compile_user_defined_triton_kernel_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_compile_user_defined_triton_kernel_tutorial.ipynb <torch_compile_user_defined_triton_kernel_tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_