• Tutorials >
  • 编译自动求导:捕获 torch.compile 的更大反向图
Shortcuts

编译自动求导:捕获 torch.compile 的更大反向图

Created On: Oct 09, 2024 | Last Updated: Oct 23, 2024 | Last Verified: Oct 09, 2024

作者: Simon Fan

What you will learn
  • 编译自动求导如何与 torch.compile 交互

  • 如何使用编译自动求导 API

  • 如何使用 TORCH_LOGS 检查日志

Prerequisites

概述

编译自动求导是 PyTorch 2.4 中引入的 torch.compile 扩展,允许捕获更大的反向图。

尽管 torch.compile 能捕获反向图,但它是 部分 捕获。AOTAutograd 组件提前捕获反向图,但存在一定限制:

  • 正向中的图断点会导致反向中的图断点

  • 反向钩子 不被捕获

编译自动求导通过直接与自动求导引擎集成解决了这些限制,使其能够在运行时捕获完整的反向图。具有以下两种特点的模型应尝试编译自动求导,并可能观察到更好的性能。

然而,编译自动求导也引入了自身的限制:

  • 在反向开始时增加了查找缓存的运行时开销

  • 由于较大的捕获,更容易在 Dynamo 中发生重新编译和图断点

备注

编译自动求导仍在积极开发中,与现有 PyTorch 功能尚未完全兼容。有关某项功能的最新状态,请参阅 编译自动求导主页

设置

在本教程中,我们将基于这个简单的神经网络模型进行示例。它接受一个10维输入向量,通过一个单线性层处理,并输出另一个10维向量。

import torch

class Model(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(10, 10)

   def forward(self, x):
      return self.linear(x)

基础用法

在调用 torch.compile API 之前,请确保将 torch._dynamo.config.compiled_autograd 设置为 True

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
   loss = model(x).sum()
   loss.backward()

train(model, x)

在上面的代码中,我们创建了一个 Model 类的实例,并通过使用 torch.randn(10) 生成一个随机的10维张量 x。我们定义了训练循环函数 train 并使用 @torch.compile 装饰以优化其执行。当调用 train(model, x) 时:

  • 由于调用被 @torch.compile 装饰,Python 解释器调用 Dynamo。

  • Dynamo 拦截Python字节码,模拟其执行并将操作记录到图中。

  • AOTDispatcher 禁用钩子并调用自动求导引擎计算 model.linear.weightmodel.linear.bias 的梯度,并将操作记录到图中。使用 torch.autograd.Function,AOTDispatcher 重写了 train 的正向和反向实现。

  • Inductor 生成了一个对应于 AOTDispatcher 正向和反向优化实现的函数。

  • Dynamo 设置优化后的函数,由Python解释器接下来进行评估。

  • Python解释器执行优化后的函数,该函数执行 loss = model(x).sum()

  • Python解释器执行``loss.backward()``,调用自动微分引擎,因为我们设置了``torch._dynamo.config.compiled_autograd = True``,所以路由到编译自动微分引擎。

  • 编译自动微分引擎计算``model.linear.weight``和``model.linear.bias``的梯度,并将操作记录到一个图中,包括遇到的任何钩子。在此过程中,它会记录之前通过AOTDispatcher重写的反向操作。然后,编译自动微分引擎生成一个新函数,该函数对应于``loss.backward()``的完全追踪实现,并在推理模式下用``torch.compile``执行它。

  • 相同的步骤递归地应用于编译自动微分图,但这次AOTDispatcher不需要对图进行划分。

检查编译自动微分日志

使用``TORCH_LOGS``环境变量运行脚本:

  • 如果只需打印编译的自动微分图,请使用``TORCH_LOGS=”compiled_autograd” python example.py``。

  • 如果要打印包含更多张量元数据的图以及重新编译的原因,请使用``TORCH_LOGS=”compiled_autograd_verbose” python example.py``,但性能会有所下降。

重新运行上述代码片段,编译的自动微分图现在应记录到``stderr``中。某些图节点的名称会以``aot0_``为前缀,这些节点对应于之前在AOT自动微分反向图0中编译的节点,例如``aot0_view_2``对应于id=0的AOT反向图中的``view_2``。

下面的图像中,红框包裹的是在没有编译自动微分的情况下由``torch.compile``捕获的AOT反向图。

../_images/entire_verbose_log.png

备注

这是我们将调用``torch.compile``的图,**不是**优化后的图。编译自动微分本质上生成了一些未优化的Python代码以表示整个C++自动微分执行。

使用不同标志编译正向和反向传递

你可以为两种编译使用不同的编译器配置,例如,即使正向中有图中断,反向仍可以是完整图。

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

或者你可以使用上下文管理器,它会在其作用范围内应用于所有自动微分调用。

def train(model, x):
   model = torch.compile(model)
   loss = model(x).sum()
   with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
      loss.backward()

编译自动微分解决了AOTAutograd的某些限制

  1. 正向传递中的图中断不再必然导致反向传递中的图中断:

@torch.compile(backend="aot_eager")
def fn(x):
   # 1st graph
   temp = x + 10
   torch._dynamo.graph_break()
   # 2nd graph
   temp = temp + 10
   torch._dynamo.graph_break()
   # 3rd graph
   return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

在第一个``torch.compile``案例中,由于在编译函数``fn``中有2个图中断,生成了3个反向图。而在第二个开启编译自动微分的``torch.compile``案例中,尽管有图中断,仍然追踪到了一个完整的反向图。

备注

在追踪由编译自动微分捕获的反向钩子时,Dynamo仍可能发生图中断。

  1. 现在可以捕获反向钩子

@torch.compile(backend="aot_eager")
def fn(x):
   return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

图中应有一个``call_hook``节点,随后dynamo会将其内联为以下内容:

../_images/call_hook_node.png

编译自动微分常见的重新编译原因

  1. 由于损失值的自动微分结构发生变化:

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
   loss = op(x, x).sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的例子中,我们每次迭代调用不同的操作,导致``loss``每次都跟踪不同的自动微分历史。你应该会看到一些重新编译的消息:由于新的自动微分节点导致缓存未命中

../_images/recompile_due_to_node.png
  1. 由于张量形状变化:

torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
   x = torch.randn(i, i, requires_grad=True)
   loss = x.sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的例子中,``x``的形状发生变化,编译自动微分将在第一个变化后将``x``标记为动态形状张量。你应该会看到重新编译的消息:由于形状变化导致缓存未命中

../_images/recompile_due_to_dynamic.png

结论

在本教程中,我们介绍了``torch.compile``与编译自动微分相关的高层生态系统、编译自动微分的基本知识以及一些常见的重新编译原因。请关注`开发讨论 <https://dev-discuss.pytorch.org/>`_中的深入内容。

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源