自定义函数的双重反向¶
Created On: Aug 13, 2021 | Last Updated: Aug 13, 2021 | Last Verified: Nov 05, 2024
有时候通过反向图再次运行反向是有用的,例如计算高阶梯度。然而,支持双重反向需要对自动微分的理解以及一些谨慎。支持单次反向的函数不一定具备支持双重反向的能力。在本教程中,我们展示如何编写支持双重反向的自定义自动微分函数,并指出一些需要注意的事项。
在编写自定义自动微分函数以实现双重反向时,重要的是了解自定义函数中执行的操作何时会被自动微分记录,何时不会,最重要的是,`save_for_backward`如何与这一切交互。
自定义函数对梯度模式有两个隐性影响:
在正向中,自动微分不会记录在正向函数内部执行的任何操作的图。当正向完成时,自定义函数的反向函数成为每个正向输出的`grad_fn`。
在反向中,如果指定了create_graph,自动微分会记录用于计算反向传递的计算图。
接下来,为了理解`save_for_backward`如何与上述内容交互,我们可以探索几个例子:
保存输入¶
考虑这个简单的平方函数。它保存了一个输入张量以供后向计算使用。当autograd能够记录后向过程中的操作时,双重后向会自动运行。因此在保存一个输入用于后向计算时通常不需要担心,因为如果该输入是任何需要梯度的张量的函数,它应该具有一个grad_fn。这允许梯度被正确传播。
import torch
class Square(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# Because we are saving one of the inputs use `save_for_backward`
# Save non-tensors and non-inputs/non-outputs directly on ctx
ctx.save_for_backward(x)
return x**2
@staticmethod
def backward(ctx, grad_out):
# A function support double backward automatically if autograd
# is able to record the computations performed in backward
x, = ctx.saved_tensors
return grad_out * 2 * x
# Use double precision because finite differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)
我们可以使用torchviz来可视化计算图,从而理解为什么这会有效。
import torchviz
x = torch.tensor(1., requires_grad=True).clone()
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
我们可以看到相对于变量x的梯度,实际上是x的函数 (dout/dx = 2x),并且这个函数的计算图已经被正确构造。

保存输出值¶
前一个例子的一个小变化是保存输出值而不是输入值。机制类似,因为输出值也与一个grad_fn相关联。
class Exp(torch.autograd.Function):
# Simple case where everything goes well
@staticmethod
def forward(ctx, x):
# This time we save the output
result = torch.exp(x)
# Note that we should use `save_for_backward` here when
# the tensor saved is an ouptut (or an input).
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_out):
result, = ctx.saved_tensors
return result * grad_out
x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)
使用torchviz可视化计算图:
out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

保存中间结果¶
一个更复杂的情况是需要保存中间结果。我们通过实现以下情况来演示:
由于sinh的导数是cosh,在后向计算中重用`exp(x)`和`exp(-x)`这两个中间结果可能很有用。
然而,中间结果不应直接保存并在后向计算中使用。因为前向过程是在无梯度模式下执行的,如果前向过程的中间结果被用来在后向过程中计算梯度,那么梯度的后向图不会包括计算中间结果的操作。这会导致梯度计算错误。
class Sinh(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
expx = torch.exp(x)
expnegx = torch.exp(-x)
ctx.save_for_backward(expx, expnegx)
# In order to be able to save the intermediate results, a trick is to
# include them as our outputs, so that the backward graph is constructed
return (expx - expnegx) / 2, expx, expnegx
@staticmethod
def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
expx, expnegx = ctx.saved_tensors
grad_input = grad_out * (expx + expnegx) / 2
# We cannot skip accumulating these even though we won't use the outputs
# directly. They will be used later in the second backward.
grad_input += _grad_out_exp * expx
grad_input -= _grad_out_negexp * expnegx
return grad_input
def sinh(x):
# Create a wrapper that only returns the first output
return Sinh.apply(x)[0]
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)
使用torchviz可视化计算图:
out = sinh(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

保存中间结果:错误示例¶
现在我们展示当没有将中间结果作为输出返回时会发生什么情况:`grad_x`甚至不会有一个后向计算图,因为它是`exp`和`expnegx`的纯函数,而这两个函数不需要梯度。
class SinhBad(torch.autograd.Function):
# This is an example of what NOT to do!
@staticmethod
def forward(ctx, x):
expx = torch.exp(x)
expnegx = torch.exp(-x)
ctx.expx = expx
ctx.expnegx = expnegx
return (expx - expnegx) / 2
@staticmethod
def backward(ctx, grad_out):
expx = ctx.expx
expnegx = ctx.expnegx
grad_input = grad_out * (expx + expnegx) / 2
return grad_input
使用torchviz可视化计算图。注意`grad_x`并不在计算图的一部分!
out = SinhBad.apply(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

当后向过程不能被追踪时¶
最后,我们考虑一个示例,其中autograd可能完全无法追踪函数的后向梯度。可以设想`cube_backward`是一个可能需要非PyTorch库(如SciPy或NumPy)或以C++扩展编写的函数。这里展示的解决方法是创建另一个自定义函数`CubeBackward`,并手动指定`cube_backward`的后向过程!
def cube_forward(x):
return x**3
def cube_backward(grad_out, x):
return grad_out * 3 * x**2
def cube_backward_backward(grad_out, sav_grad_out, x):
return grad_out * sav_grad_out * 6 * x
def cube_backward_backward_grad_out(grad_out, x):
return grad_out * 3 * x**2
class Cube(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return cube_forward(x)
@staticmethod
def backward(ctx, grad_out):
x, = ctx.saved_tensors
return CubeBackward.apply(grad_out, x)
class CubeBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_out, x):
ctx.save_for_backward(x, grad_out)
return cube_backward(grad_out, x)
@staticmethod
def backward(ctx, grad_out):
x, sav_grad_out = ctx.saved_tensors
dx = cube_backward_backward(grad_out, sav_grad_out, x)
dgrad_out = cube_backward_backward_grad_out(grad_out, x)
return dgrad_out, dx
x = torch.tensor(2., requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)
使用torchviz可视化计算图:
out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

总结起来,是否双重后向适用于自定义函数仅仅取决于后向过程是否可以被autograd追踪。在前两个示例中,我们展示了双重后向自动适用的情况。在第三和第四个示例中,我们演示了使后向函数能够被追踪的技巧,否则就无法追踪。