备注
点击 这里 下载完整示例代码
(测试版)结合 torch.compile 使用 Torch Function modes¶
- 本教程介绍如何在
torch.compile
中使用重要的 PyTorch 扩展点,即 torch Function modes,以便在不增加运行时开销的情况下,在跟踪时覆写 torch 操作符(也称为 ops)的行为。 此教程需要 PyTorch 2.7.0 或更高版本。
备注
重写 torch 操作符(torch.add -> torch.mul)
在本示例中,我们将使用 torch Function modes 将加法的出现改写为乘法。这种类型的覆写很常见,如果某个特定的后端对某个操作符有自定义实现,就应该为其分派此实现。¶
For this example, we’ll use torch function modes to rewrite occurences of addition with multiply instead. This type of override can be common if a certain backend has a custom implementation that should be dispatched for a given op.
import torch
# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)
from torch.overrides import BaseTorchFunctionMode
# Define our mode, Note: ``BaseTorchFunctionMode``
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if func == torch.Tensor.add:
func = torch.mul
return super().__torch_function__(func, types, args, kwargs)
@torch.compile()
def test_fn(x, y):
return x + y * x # Note: infix operators map to torch.Tensor.* methods
x = torch.rand(2, 2)
y = torch.rand_like(x)
with AddToMultiplyMode():
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
# The mode can also be used within the compiled region as well like this:
@torch.compile()
def test_fn(x, y):
with AddToMultiplyMode():
return x + y * x # Note: infix operators map to torch.Tensor.* methods
x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
结论¶
在这个教程中,我们演示了如何在 torch.compile
中使用 torch 函数模式覆盖 torch.*
操作的行为。这使用户能够利用 torch 函数模式的扩展性优势,而无需在每次操作调用时承受调用 torch 函数的运行时开销。
参见 使用模式扩展 Torch API ,了解其他示例和关于 Torch 函数模式的背景知识。
脚本的总运行时间: (0分钟 0.000秒)