• Tutorials >
  • 通过 torch.compiler.set_stance 动态控制编译
Shortcuts

通过 torch.compiler.set_stance 动态控制编译

作者: William Wen

torch.compiler.set_stancetorch.compiler 的一种 API,允许您在多次调用模型时更改 torch.compile 的行为,而无需对模型重新应用 torch.compile

本教程提供了一些关于如何使用 torch.compiler.set_stance 的示例。

前提条件

  • torch >= 2.6

描述

torch.compile.set_stance 可以作为装饰器、上下文管理器或普通函数使用,以更改 torch.compile 在不同模型调用中的行为。

在以下示例中,"force_eager" 状态会忽略所有 torch.compile 指令。

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp))  # compiled, prints 1

装饰器用法示例

@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1

上下文管理器用法示例

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1

普通函数用法示例

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp))  # compiled, prints 1

torch.compile 状态只能在任何 torch.compile 区域**之外**进行更改。尝试违反此规则会导致错误。

@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
其他状态包括:
  • "default":默认状态,用于正常编译。

  • "eager_on_recompile":在需要重新编译时以急切模式运行代码。如果存在适合输入的缓存编译代码,仍然会使用该代码。

  • "fail_on_recompile":在重新编译函数时抛出错误。

查看 torch.compiler.set_stance文档页面 ,了解更多状态和选项。将来也可能会添加更多状态/选项。

示例

防止重新编译

某些模型不希望有任何重新编译。例如,您可能总是有相同形状的输入。由于重新编译可能代价昂贵,我们可能希望在尝试重新编译时抛出错误,以便检测并修复相关问题。可以使用 "fail_on_recompilation" 状态。

@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)

如果抛出错误过于中断流程,我们可以改用 "eager_on_recompile",这会使 torch.compile 回退到急切模式,而不是抛出错误。如果我们不期望重新编译频繁发生,但当确实需要时,更愿意接受急切运行的成本而非重新编译的成本,这种方式可能更有用。

@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1

测量性能提升

torch.compiler.set_stance 可用于在不定义单独急切模型的情况下比较急切模式和编译模式的性能。

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])

尽早发现错误

通过在编译迭代之前先运行一个急切模式迭代(使用 "force_eager" 状态),可以帮助我们在尝试长时间编译之前捕捉与 torch.compile 无关的错误。

@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)

结论

在本教程中,我们学习了如何使用 torch.compiler.set_stance API 来在不需要重新应用 torch.compile 的情况下修改对模型的多次调用中的编译行为。教程展示了将 torch.compiler.set_stance 作为装饰器、上下文管理器或普通函数使用,以控制 force_eagerdefaulteager_on_recompilefail_on_recompile 等编译状态。

更多信息,参见:torch.compiler.set_stance API 文档

脚本的总运行时间: (0分钟 0.000秒)

由Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源