Shortcuts

torch.compile 中的编译时缓存

Created On: Jun 20, 2024 | Last Updated: Feb 27, 2025 | Last Verified: Nov 05, 2024

作者: Oguz Ulgen

简介

PyTorch 编译器提供了多种缓存功能,用于减少编译延迟。本教程将详细解释这些功能,以帮助用户选择其用例的最佳选项。

查看 Compile Time Caching Configurations 了解如何配置这些缓存。

还可以查看我们的缓存基准测试 PT CacheBench Benchmarks

前提条件

开始此教程之前,请确保您具有以下内容:

缓存功能

torch.compile 提供了以下缓存功能:

  • 端到端缓存(也称为 Mega-Cache

  • TorchDynamoTorchInductorTriton 的模块化缓存

需要注意的是,缓存会验证缓存工件是否与相同的 PyTorch 和 Triton 版本,以及在设备设置为 CUDA 时与相同的 GPU 一起使用。

torch.compile 端到端缓存(Mega-Cache

端到端缓存,从此之后称为 Mega-Cache,是为寻求可移植缓存解决方案的用户设计的理想选择,该解决方案可以存储在数据库中,并可能在另一台机器上获取。

Mega-Cache 提供了两个编译器 API:

  • torch.compiler.save_cache_artifacts()

  • torch.compiler.load_cache_artifacts()

预期用例是在编译并执行模型后,用户调用 torch.compiler.save_cache_artifacts(),它将以可移植的形式返回编译器工件。稍后,可能在另一台机器上,用户可以使用这些工件调用 torch.compiler.load_cache_artifacts(),以预填充 torch.compile 缓存,从而快速启动其缓存。

请考虑以下示例。首先,编译并保存缓存工件。

@torch.compile
def fn(x, y):
    return x.sin() @ y

a = torch.rand(100, 100, dtype=dtype, device=device)
b = torch.rand(100, 100, dtype=dtype, device=device)

result = fn(a, b)

artifacts = torch.compiler.save_cache_artifacts()

assert artifacts is not None
artifact_bytes, cache_info = artifacts

# Now, potentially store artifact_bytes in a database
# You can use cache_info for logging

稍后,您可以通过以下方式快速启动缓存:

# Potentially download/fetch the artifacts from the database
torch.compiler.load_cache_artifacts(artifact_bytes)

此操作会填充将要讨论的下一节中的所有模块缓存,包括 PGOAOTAutogradInductorTritonAutotuning

TorchDynamoTorchInductorTriton 的模块化缓存

上述 Mega-Cache 由可以无需用户干预而单独使用的组件组成。默认情况下,PyTorch 编译器自带适用于 TorchDynamoTorchInductorTriton 的本地磁盘缓存。这些缓存包括:

  • FXGraphCache:用于编译的基于图的 IR 组件的缓存。

  • TritonCache:包括由 Triton 生成的 cubin 文件及其他缓存工件的 Triton 编译结果缓存。

  • InductorCacheFXGraphCacheTriton 缓存的集合。

  • AOTAutogradCache:联合图工件的缓存。

  • PGO-cache:动态形状决策的缓存,用于减少重新编译的次数。

所有这些缓存工件都会写入 TORCHINDUCTOR_CACHE_DIR,默认情况下该目录为 /tmp/torchinductor_myusername

远程缓存

对于希望利用基于 Redis 的缓存的用户,我们还提供了一种远程缓存选项。查看 Compile Time Caching Configurations 了解如何启用基于 Redis 的缓存。

结论

在此教程中,我们了解到,PyTorch Inductor 的缓存机制通过利用本地和远程缓存,大大减少了编译延迟。这些缓存无缝运行,无需用户干预。

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源