.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "recipes/torch_compile_torch_function_modes.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_recipes_torch_compile_torch_function_modes.py:


(beta) Utilizing Torch Function modes with torch.compile
============================================================

**Author:** `Michael Lazos `_

.. GENERATED FROM PYTHON SOURCE LINES 9-16

This recipe shows how to combine a key PyTorch extensibility point,
torch function modes, with ``torch.compile`` to override the behavior
of torch operators (also known as **ops**) at trace time, with no
runtime overhead.

.. note::

   This recipe requires PyTorch 2.7.0 or later.

.. GENERATED FROM PYTHON SOURCE LINES 19-25

Rewriting a torch op (torch.add -> torch.mul)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For this example, we'll use torch function modes to replace occurrences
of addition with multiplication. This kind of override is common when a
backend has a custom implementation that should be dispatched for a given op.

.. GENERATED FROM PYTHON SOURCE LINES 25-69

.. code-block:: default


    import torch

    # Exit cleanly if we are on a CUDA device that doesn't support ``torch.compile``.
    # The ``is_available()`` guard keeps this check from raising on machines
    # without CUDA (the tensors in this example are created on the CPU).
    if torch.cuda.is_available() and torch.cuda.get_device_capability() < (7, 0):
        print("Exiting because torch.compile is not supported on this device.")
        import sys
        sys.exit(0)

    from torch.overrides import BaseTorchFunctionMode


    # Define our mode. Note: ``BaseTorchFunctionMode``
    # implements the actual invocation of func(...)
    class AddToMultiplyMode(BaseTorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            if func == torch.Tensor.add:
                func = torch.mul

            return super().__torch_function__(func, types, args, kwargs)


    @torch.compile()
    def test_fn(x, y):
        return x + y * x  # Note: infix operators map to torch.Tensor.* methods


    x = torch.rand(2, 2)
    y = torch.rand_like(x)

    with AddToMultiplyMode():
        z = test_fn(x, y)

    assert torch.allclose(z, x * y * x)

    # The mode can also be used within the compiled region, like this:


    @torch.compile()
    def test_fn(x, y):
        with AddToMultiplyMode():
            return x + y * x  # Note: infix operators map to torch.Tensor.* methods


    x = torch.rand(2, 2)
    y = torch.rand_like(x)
    z = test_fn(x, y)
    assert torch.allclose(z, x * y * x)



.. GENERATED FROM PYTHON SOURCE LINES 70-78

Conclusion
~~~~~~~~~~

In this recipe, we demonstrated how to override the behavior of ``torch.*``
operators using torch function modes together with ``torch.compile``. Because
the mode is applied while ``torch.compile`` traces the function, users get the
extensibility benefits of torch function modes without the runtime overhead of
calling ``__torch_function__`` on every op invocation.

* See `Extending Torch API with Modes `__ for other examples and background on Torch Function modes.
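
The runtime overhead mentioned in the conclusion can be illustrated with a minimal sketch that runs a torch function mode in eager mode, without ``torch.compile``. ``RecordingMode`` is a hypothetical helper written only for this illustration (it is not a PyTorch API): it subclasses ``BaseTorchFunctionMode``, appends every ``func`` it intercepts to a list, and then lets the base class perform the actual call.

.. code-block:: python

    import torch
    from torch.overrides import BaseTorchFunctionMode


    # Hypothetical helper for illustration (not part of PyTorch): records every
    # function that reaches ``__torch_function__`` while the mode is active.
    class RecordingMode(BaseTorchFunctionMode):
        def __init__(self):
            super().__init__()
            self.calls = []

        def __torch_function__(self, func, types, args=(), kwargs=None):
            self.calls.append(func)
            # ``BaseTorchFunctionMode`` performs the actual call to ``func``.
            return super().__torch_function__(func, types, args, kwargs)


    # Create the inputs outside the mode so that factory functions such as
    # ``torch.rand`` are not recorded as well.
    x = torch.rand(2, 2)
    y = torch.rand_like(x)

    mode = RecordingMode()
    with mode:
        for _ in range(3):
            # Eager execution: the handler runs for every op on every iteration.
            z = x + y * x

    print(len(mode.calls))
    print([f.__name__ for f in mode.calls])

With only this mode active, the loop is expected to record six entries that alternate between ``torch.Tensor.mul`` and ``torch.Tensor.add``: the infix ``*`` and ``+`` operators reach the mode as ``torch.Tensor.*`` methods, and the handler is entered again on every iteration. That per-call dispatch is the eager-mode cost that the recipe's compiled examples aim to avoid by applying the mode at trace time.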