备注
点击 此处 下载完整示例代码
通过 torch.compiler.set_stance
动态控制编译¶
作者: William Wen
torch.compiler.set_stance
是 torch.compiler
的一种 API,允许您在多次调用模型时更改 torch.compile
的行为,而无需对模型重新应用 torch.compile
。
本教程提供了一些关于如何使用 torch.compiler.set_stance
的示例。
描述¶
torch.compile.set_stance
可以作为装饰器、上下文管理器或普通函数使用,以更改 torch.compile
在不同模型调用中的行为。
在以下示例中,"force_eager"
状态会忽略所有 torch.compile
指令。
import torch
@torch.compile
def foo(x):
if torch.compiler.is_compiling():
# torch.compile is active
return x + 1
else:
# torch.compile is not active
return x - 1
inp = torch.zeros(3)
print(foo(inp)) # compiled, prints 1
装饰器用法示例
@torch.compiler.set_stance("force_eager")
def bar(x):
# force disable the compiler
return foo(x)
print(bar(inp)) # not compiled, prints -1
上下文管理器用法示例
with torch.compiler.set_stance("force_eager"):
print(foo(inp)) # not compiled, prints -1
普通函数用法示例
torch.compiler.set_stance("force_eager")
print(foo(inp)) # not compiled, prints -1
torch.compiler.set_stance("default")
print(foo(inp)) # compiled, prints 1
torch.compile
状态只能在任何 torch.compile
区域**之外**进行更改。尝试违反此规则会导致错误。
@torch.compile
def baz(x):
# error!
with torch.compiler.set_stance("force_eager"):
return x + 1
try:
baz(inp)
except Exception as e:
print(e)
@torch.compiler.set_stance("force_eager")
def inner(x):
return x + 1
@torch.compile
def outer(x):
# error!
return inner(x)
try:
outer(inp)
except Exception as e:
print(e)
- 其他状态包括:
"default"
:默认状态,用于正常编译。"eager_on_recompile"
:在需要重新编译时以急切模式运行代码。如果存在适合输入的缓存编译代码,仍然会使用该代码。"fail_on_recompile"
:在重新编译函数时抛出错误。
查看 torch.compiler.set_stance
的 文档页面 ,了解更多状态和选项。将来也可能会添加更多状态/选项。
示例¶
防止重新编译¶
某些模型不希望有任何重新编译。例如,您可能总是有相同形状的输入。由于重新编译可能代价昂贵,我们可能希望在尝试重新编译时抛出错误,以便检测并修复相关问题。可以使用 "fail_on_recompilation"
状态。
@torch.compile
def my_big_model(x):
return torch.relu(x)
# first compilation
my_big_model(torch.randn(3))
with torch.compiler.set_stance("fail_on_recompile"):
my_big_model(torch.randn(3)) # no recompilation - OK
try:
my_big_model(torch.randn(4)) # recompilation - error
except Exception as e:
print(e)
如果抛出错误过于中断流程,我们可以改用 "eager_on_recompile"
,这会使 torch.compile
回退到急切模式,而不是抛出错误。如果我们不期望重新编译频繁发生,但当确实需要时,更愿意接受急切运行的成本而非重新编译的成本,这种方式可能更有用。
@torch.compile
def my_huge_model(x):
if torch.compiler.is_compiling():
return x + 1
else:
return x - 1
# first compilation
print(my_huge_model(torch.zeros(3))) # 1
with torch.compiler.set_stance("eager_on_recompile"):
print(my_huge_model(torch.zeros(3))) # 1
print(my_huge_model(torch.zeros(4))) # -1
print(my_huge_model(torch.zeros(3))) # 1
测量性能提升¶
torch.compiler.set_stance
可用于在不定义单独急切模型的情况下比较急切模式和编译模式的性能。
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
@torch.compile
def my_gigantic_model(x, y):
x = x @ y
x = x @ y
x = x @ y
return x
inps = torch.randn(5, 5), torch.randn(5, 5)
with torch.compiler.set_stance("force_eager"):
print("eager:", timed(lambda: my_gigantic_model(*inps))[1])
# warmups
for _ in range(3):
my_gigantic_model(*inps)
print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
尽早发现错误¶
通过在编译迭代之前先运行一个急切模式迭代(使用 "force_eager"
状态),可以帮助我们在尝试长时间编译之前捕捉与 torch.compile
无关的错误。
@torch.compile
def my_humongous_model(x):
return torch.sin(x, x)
try:
with torch.compiler.set_stance("force_eager"):
print(my_humongous_model(torch.randn(3)))
# this call to the compiled model won't run
print(my_humongous_model(torch.randn(3)))
except Exception as e:
print(e)
结论¶
在本教程中,我们学习了如何使用 torch.compiler.set_stance
API 来在不需要重新应用 torch.compile
的情况下修改对模型的多次调用中的编译行为。教程展示了将 torch.compiler.set_stance
作为装饰器、上下文管理器或普通函数使用,以控制 force_eager
、default
、eager_on_recompile
和 fail_on_recompile
等编译状态。
更多信息,参见:torch.compiler.set_stance API 文档。
脚本的总运行时间: (0分钟 0.000秒)