• Tutorials >
  • (测试版)结合 torch.compile 使用 Torch Function modes
Shortcuts

(测试版)结合 torch.compile 使用 Torch Function modes

作者Michael Lazos

本教程介绍如何在 torch.compile 中使用重要的 PyTorch 扩展点,即 torch Function modes,以便在不增加运行时开销的情况下,在跟踪时覆写 torch 操作符(也称为 ops)的行为。

此教程需要 PyTorch 2.7.0 或更高版本。

备注

重写 torch 操作符(torch.add -> torch.mul)

在本示例中,我们将使用 torch Function modes 将加法的出现改写为乘法。这种类型的覆写很常见,如果某个特定的后端对某个操作符有自定义实现,就应该为其分派此实现。

For this example, we’ll use torch function modes to rewrite occurences of addition with multiply instead. This type of override can be common if a certain backend has a custom implementation that should be dispatched for a given op.

import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode, Note: ``BaseTorchFunctionMode``
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)

@torch.compile()
def test_fn(x, y):
    return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

# The mode can also be used within the compiled region as well like this:

@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

结论

在这个教程中,我们演示了如何在 torch.compile 中使用 torch 函数模式覆盖 torch.* 操作的行为。这使用户能够利用 torch 函数模式的扩展性优势,而无需在每次操作调用时承受调用 torch 函数的运行时开销。

脚本的总运行时间: (0分钟 0.000秒)

由Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源