PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

自定义函数的双重反向传播

在某些情况下,两次反向遍历计算图是有用的,例如计算高阶梯度。然而,支持双重反向传播需要对自动求导(autograd)有一定的理解,并且需要格外小心。那些支持单次反向传播的函数并不一定能够支持双重反向传播。在本教程中,我们将展示如何编写一个支持双重反向传播的自定义 autograd 函数,并指出一些需要注意的事项。

在编写一个支持双重反向传播的自定义 autograd 函数时,了解自定义函数中的操作何时被 autograd 记录、何时不被记录,以及最重要的是 save_for_backward 如何与这一切配合使用,是非常重要的。

自定义函数以两种方式隐式地影响梯度模式:

  • 在前向传播过程中,autograd 不会记录在前向函数内执行的任何操作的图。当前向传播完成后,自定义函数的反向函数将成为前向传播输出的 grad_fn。

  • 在反向传播过程中,如果指定了 create_graph,autograd 会记录用于计算反向传播的计算图。

接下来,为了理解 save_for_backward 如何与上述内容交互,我们可以探讨几个示例:

保存输入

考虑这个简单的平方函数。它保存了一个输入张量用于反向传播。当 autograd 能够在反向传播过程中记录操作时,双反向传播会自动工作,因此通常在我们保存一个输入用于反向传播时,无需担心,因为如果该输入是任何需要梯度的张量的函数,它应该具有 grad_fn。这使得梯度能够正确传播。

importtorch

classSquare(torch.autograd.Function):
    @staticmethod
    defforward(ctx, x):
        # Because we are saving one of the inputs use `save_for_backward`
        # Save non-tensors and non-inputs/non-outputs directly on ctx
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    defbackward(ctx, grad_out):
        # A function support double backward automatically if autograd
        # is able to record the computations performed in backward
        x, = ctx.saved_tensors
        return grad_out * 2 * x

# Use double precision because finite differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)

我们可以使用 torchviz 来可视化图形,以理解其为何有效。

importtorchviz

x = torch.tensor(1., requires_grad=True).clone()
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

我们可以看到,关于 x 的梯度本身是 x 的函数(dout/dx = 2x),并且该函数的图形已经正确构建。

https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png

保存输出

前一个例子的一个细微变化是保存输出而不是输入。其机制类似,因为输出也关联着一个 grad_fn。

classExp(torch.autograd.Function):
    # Simple case where everything goes well
    @staticmethod
    defforward(ctx, x):
        # This time we save the output
        result = torch.exp(x)
        # Note that we should use `save_for_backward` here when
        # the tensor saved is an ouptut (or an input).
        ctx.save_for_backward(result)
        return result

    @staticmethod
    defbackward(ctx, grad_out):
        result, = ctx.saved_tensors
        return result * grad_out

x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)

使用 torchviz 可视化计算图:

out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png

保存中间结果

一个更复杂的情况是当我们需要保存一个中间结果时。我们通过以下实现来演示这种情况:

\[sinh(x) := \frac{e^x - e^{-x}}{2} \]

由于 sinh 的导数是 cosh,因此在反向计算中重用前向计算中的两个中间结果 exp(x) 和 exp(-x) 可能会有帮助。

然而,不应该直接保存并使用前向计算中的中间结果来进行反向计算。因为前向计算是在无梯度模式下执行的,如果使用前向计算的中间结果来计算反向传播中的梯度,那么梯度的反向计算图将不会包含计算该中间结果的操作。这会导致梯度计算错误。

classSinh(torch.autograd.Function):
    @staticmethod
    defforward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.save_for_backward(expx, expnegx)
        # In order to be able to save the intermediate results, a trick is to
        # include them as our outputs, so that the backward graph is constructed
        return (expx - expnegx) / 2, expx, expnegx

    @staticmethod
    defbackward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
        expx, expnegx = ctx.saved_tensors
        grad_input = grad_out * (expx + expnegx) / 2
        # We cannot skip accumulating these even though we won't use the outputs
        # directly. They will be used later in the second backward.
        grad_input += _grad_out_exp * expx
        grad_input -= _grad_out_negexp * expnegx
        return grad_input

defsinh(x):
    # Create a wrapper that only returns the first output
    return Sinh.apply(x)[0]

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)

使用 torchviz 可视化图表:

out = sinh(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png

保存中间结果:不应采取的做法

现在我们展示一下如果不将中间结果作为输出返回会发生什么情况:grad_x 甚至不会有反向计算图,因为它纯粹是 exp 和 expnegx 的函数,而这些函数不需要梯度。

classSinhBad(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    defforward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.expx = expx
        ctx.expnegx = expnegx
        return (expx - expnegx) / 2

    @staticmethod
    defbackward(ctx, grad_out):
        expx = ctx.expx
        expnegx = ctx.expnegx
        grad_input = grad_out * (expx + expnegx) / 2
        return grad_input

使用 torchviz 可视化计算图。注意 grad_x 并不在图中!

out = SinhBad.apply(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png

当反向传播未被追踪时

最后,我们来看一个例子,说明在某些情况下,autograd 可能无法跟踪函数反向传播的梯度。我们可以想象 cube_backward 是一个可能需要使用非 PyTorch 库(如 SciPy 或 NumPy)的函数,或者它是用 C++ 扩展编写的。这里的解决方案是创建另一个自定义函数 CubeBackward,并在其中手动指定 cube_backward 的反向传播逻辑!

defcube_forward(x):
    return x**3

defcube_backward(grad_out, x):
    return grad_out * 3 * x**2

defcube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

defcube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

classCube(torch.autograd.Function):
    @staticmethod
    defforward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

    @staticmethod
    defbackward(ctx, grad_out):
        x, = ctx.saved_tensors
        return CubeBackward.apply(grad_out, x)

classCubeBackward(torch.autograd.Function):
    @staticmethod
    defforward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

    @staticmethod
    defbackward(ctx, grad_out):
        x, sav_grad_out = ctx.saved_tensors
        dx = cube_backward_backward(grad_out, sav_grad_out, x)
        dgrad_out = cube_backward_backward_grad_out(grad_out, x)
        return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)

使用 torchviz 可视化图形:

out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png

总结来说,自定义函数是否支持 double backward 操作,完全取决于其反向传播过程能否被 autograd 追踪。通过前两个示例,我们展示了那些开箱即用就支持 double backward 的情形。而在第三和第四个示例中,我们则演示了如何通过特定技术手段,使得原本无法被追踪的反向传播函数变得可追踪。

本页目录