Shortcuts

torch.compile 简介

Created On: Mar 15, 2023 | Last Updated: Apr 28, 2025 | Last Verified: Nov 05, 2024

作者: William Wen

torch.compile 是一种最新方法,可以加速您的PyTorch代码!torch.compile 通过JIT编译PyTorch代码为优化内核,从而让PyTorch代码运行得更快,同时需要的代码改动最少。

在本教程中,我们介绍 torch.compile 的基础用法,并演示其相对于之前PyTorch编译器解决方案(如`TorchScript <https://pytorch.org/docs/stable/jit.html>`__ 和 FX Tracing)的优势。

目录

所需的pip依赖

  • torch >= 2.0

  • torchvision

  • numpy

  • scipy

  • tabulate

系统要求 - 一个C++编译器,例如 g++ - Python开发包 (python-devel/python-dev)

注意:推荐使用现代NVIDIA GPU(H100、A100或V100)以再现以下所示和其他地方记录的加速数值。

import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )

基本用法

torch.compile 包含在最新的PyTorch中。在GPU上运行TorchInductor需要Triton,Triton包含在PyTorch 2.0夜间版中。如果仍缺少Triton,请尝试通过pip安装``torchtriton``(pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117" 适用于CUDA 11.7)。

可以通过将可调用对象传递给``torch.compile``来优化任意Python函数。然后,我们可以用返回的优化函数代替原始函数进行调用。

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

或者,我们也可以装饰该函数。

t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(t1, t2))

我们也可以优化``torch.nn.Module``实例。

t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
mod.compile()
print(mod(t))
## or:
# opt_mod = torch.compile(mod)
# print(opt_mod(t))

torch.compile 和嵌套调用

装饰函数中的嵌套函数调用也会被编译。

def nested_function(x):
    return torch.sin(x)

@torch.compile
def outer_function(x, y):
    a = nested_function(x)
    b = torch.cos(y)
    return a + b

print(outer_function(t1, t2))

同样,当编译一个模块时,所有非跳过列表中的子模块和方法也都会被编译。

class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.inner_module(x)
        return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
outer_mod.compile()
print(outer_mod(t))

我们还可以使用``torch.compiler.disable``禁用某些函数的编译。假设您想仅对``complex_function``函数禁用跟踪,但希望在``complex_conjugate``中恢复跟踪。在这种情况下,您可以使用 torch.compiler.disable(recursive=False) 选项。否则,默认是``recursive=True``。

def complex_conjugate(z):
    return torch.conj(z)

@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
    # Assuming this function cause problems in the compilation
    z = torch.complex(real, imag)
    return complex_conjugate(z)

def outer_function():
    real = torch.tensor([2, 3], dtype=torch.float32)
    imag = torch.tensor([4, 5], dtype=torch.float32)
    z = complex_function(real, imag)
    return torch.abs(z)

# Try to compile the outer_function
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    print("Compilation of outer_function failed:", e)

最佳实践与建议

``torch.compile``在嵌套模块和函数调用中的行为

使用``torch.compile``时,编译器会尝试递归编译目标函数或模块内的每个函数调用,前提是这些函数或模块不在跳过列表中(例如内置函数,一些torch.*命名空间中的函数)。

最佳实践:

1. Top-Level Compilation: One approach is to compile at the highest level possible (i.e., when the top-level module is initialized/called) and selectively disable compilation when encountering excessive graph breaks or errors. If there are still many compile issues, compile individual subcomponents instead.

2. Modular Testing: Test individual functions and modules with torch.compile before integrating them into larger models to isolate potential issues.

3. Disable Compilation Selectively: If certain functions or sub-modules cannot be handled by torch.compile, use the torch.compiler.disable context managers to recursively exclude them from compilation.

4. Compile Leaf Functions First: In complex models with multiple nested functions and modules, start by compiling the leaf functions or modules first. For more information see TorchDynamo APIs for fine-grained tracing.

  1. 推荐使用 ``mod.compile()`` 而不是 ``torch.compile(mod)``: 避免``state_dict``中的``_orig_``前缀问题。

6. Use ``fullgraph=True`` to catch graph breaks: Helps ensure end-to-end compilation, maximizing speedup and compatibility with torch.export.

演示加速效果

现在让我们演示使用``torch.compile``可以加速真实模型。我们将在随机数据上评估和训练一个``torchvision``模型,并比较其标准模式(eager)和``torch.compile``模式。

在开始前,我们需要定义一些实用函数。

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()

首先,让我们比较推论。

请注意在调用``torch.compile``时,我们还有一个附加的``mode``参数,我们将随后讨论。

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

注意``torch.compile``完成的时间比标准模式(eager)长得多,因为``torch.compile``在执行时编译模型为优化内核。在我们的示例中,模型的结构不会变化,因此不需要重新编译。所以如果我们多次运行已经优化的模型,应该能看到与标准模式相比显著的改进。

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

确实,我们可以看到运行我们的模型时,``torch.compile``相对于标准模式具有显著的加速效果。加速主要是通过减少Python的开销和GPU的读写开销实现的,因此观察到的加速可能会因模型架构和批次大小等因素而有所不同。例如,如果模型架构相对简单并且数据量较大,则瓶颈可能是GPU计算,观察到的加速效果可能相对较小。

您可能还会发现基于所选择的``mode``参数会有不同的加速结果。``”reduce-overhead”``模式使用CUDA图进一步减少Python的开销。对于您自己的模型,可能需要尝试不同的模式以获得最大加速效果。您可以在`此处 <https://pytorch.org/get-started/pytorch-2.0/#user-experience>`__阅读更多有关模式的信息。

您可能还会注意到,我们第二次运行模型时比其他运行慢得多,尽管比第一次运行快得多。这是因为``”reduce-overhead”``模式会为CUDA图运行一些热身迭代。

对于一般的PyTorch基准测试,可以尝试使用``torch.utils.benchmark``,而不是我们定义的``timed``函数。我们在本教程中编写自己的计时函数是为了展示``torch.compile``的编译延迟。

现在,让我们考虑比较训练。

model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

同样,我们可以看到,``torch.compile``在第一次迭代中花费了更长时间,因为必须编译模型,但在后续迭代中,与标准模式相比,我们看到了显著的加速效果。

我们提醒,本教程中展示的加速数字仅供演示。官方加速数据可在`TorchInductor性能仪表盘 <https://hud.pytorch.org/benchmark/compilers>`__中查看。

与TorchScript和FX Tracing的比较

我们已经看到``torch.compile``可以加速PyTorch代码。那么,为什么还要选择``torch.compile``而不是现有的PyTorch编译器解决方案,例如TorchScript或FX Tracing?主要原因是``torch.compile``能够以最小的代码改动处理任意Python代码。

一个``torch.compile``可以处理但其他编译器解决方案难以处理的情况是数据依赖性的控制流(如下所示的``if x.sum() < 0:``行)。

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

对函数``f1``的TorchScript跟踪会导致无声的错误结果,因为仅实际的控制流路径被跟踪。

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

对函数``f1``的FX跟踪会因为存在数据依赖性的控制流而导致错误。

import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except:
    tb.print_exc()

如果我们为``x``提供一个值并尝试对``f1``进行FX跟踪,则会遇到与TorchScript跟踪同样的问题,因为数据依赖性的控制流在跟踪函数中被移除了。

fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))

现在我们可以看到``torch.compile``正确地处理了与数据有关的控制流。

# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

TorchScript脚本可以处理与数据有关的控制流,但是这种解决方案也有其自身的问题。即,TorchScript脚本可能需要进行大量代码更改,并且在使用不支持的Python时会引发错误。

在下面的示例中,我们忘记了TorchScript类型注解,结果因为参数``y``的输入类型是``int``,而不是默认的``torch.Tensor``类型而收到TorchScript错误。

def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except:
    tb.print_exc()

然而,torch.compile``能够轻松地处理``f2

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

另一个``torch.compile``相较于之前的编译器解决方案处理得很好的情况是对非PyTorch函数的使用。

import scipy
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

TorchScript跟踪将非PyTorch函数调用的结果视为常量,因此我们的结果可能会默默地出错。

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

TorchScript脚本和FX跟踪不允许调用非PyTorch函数。

try:
    torch.jit.script(f3)
except:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except:
    tb.print_exc()

相比之下,``torch.compile``能够轻松地处理非PyTorch函数调用。

compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))

TorchDynamo和FX图

``torch.compile``的一个重要组件是TorchDynamo。TorchDynamo负责将任意Python代码JIT编译为FX图,从而进行进一步优化。TorchDynamo通过在运行时分析Python字节码并检测对PyTorch操作的调用来提取FX图。

通常情况下,``torch.compile``的另一个组件TorchInductor会进一步将FX图编译为优化的内核,但TorchDynamo允许使用不同的后端。为了检查TorchDynamo输出的FX图,我们创建了一个自定义后端,该后端输出FX图并简单地返回图的未优化前向方法。

from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward

# Reset since we are using a different backend.
torch._dynamo.reset()

opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])

使用我们的自定义后端,现在我们可以看到TorchDynamo如何处理与数据有关的控制流。考虑下面的函数,其中``if b.sum() < 0``一行是数据相关控制流的来源。

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

输出显示TorchDynamo提取了与以下代码对应的3个不同的FX图(顺序可能与上面的输出不同):

  1. x = a / (torch.abs(a) + 1)

  2. b = b * -1; return x * b

  3. return x * b

当TorchDynamo遇到不支持的Python特性(例如数据相关的控制流)时,它会中断计算图,让默认的Python解释器处理不支持的代码,然后恢复捕获图。

让我们通过一个示例调查TorchDynamo如何逐步处理``bar``。如果``b.sum() < 0``,TorchDynamo会运行图1,让Python确定条件的结果,然后运行图2。另一方面,如果``not b.sum() < 0``,TorchDynamo会运行图1,让Python确定条件的结果,然后运行图3。

这突出体现了TorchDynamo与之前PyTorch编译器解决方案的主要区别。当遇到不支持的Python特性时,之前的解决方案要么引发错误,要么静默失败。而TorchDynamo则会中断计算图。

我们可以通过使用``torch._dynamo.explain``来查看TorchDynamo中断图的地方:

# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)

为了最大限度地提高速度,应尽量限制图中断。我们可以通过使用``fullgraph=True``强制TorchDynamo在遇到第一个图中断时引发错误:

opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()

下面,我们演示了TorchDynamo在我们上面用于演示加速的模型上并未中断图。

opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))

我们可以使用``torch.export``(从PyTorch 2.1+开始支持)从输入的PyTorch程序中提取单个、可导出的FX图。导出的图旨在在不同的(即无需Python的)环境中运行。一项重要的限制是``torch.export``不支持图中断。有关``torch.export``的更多详细信息,请查看`本教程 <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__。

结论

在本教程中,我们通过涵盖基本用法、展示相较于即时模式的加速效果、与之前的PyTorch编译器解决方案进行比较,及简要研究TorchDynamo与FX图的交互,介绍了``torch.compile``。我们希望您能尝试一下``torch.compile``!

**脚本的总运行时间:**(0分钟0.000秒)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源