自动求导机制
本笔记将概述自动求导的工作原理以及如何记录操作。虽然严格来说并不需要完全理解这些内容,但我们建议您熟悉它们,因为这将帮助您编写更高效、更清晰的程序,并在调试时提供帮助。
自动求导如何编码历史
自动求导是一种基于反向传播的自动微分系统。从概念上讲,自动求导在执行操作时会记录一个图,该图记录了所有生成数据的操作,形成一个有向无环图,其中叶子节点是输入张量,根节点是输出张量。通过从根节点追溯到叶子节点,可以利用链式法则自动计算梯度。
在内部实现中,自动求导将这个图表示为 Function
对象(实际上是表达式)的图,这些对象可以通过 apply()
方法调用来计算图的评估结果。在计算前向传播时,自动求导同时执行请求的计算并构建一个表示梯度计算函数的图(每个 torch.Tensor
的 .grad_fn
属性是进入该图的入口点)。当前向传播完成后,我们在反向传播中评估这个图以计算梯度。
需要注意的一点是,图在每次迭代时都会完全重新构建,这正是能够使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时改变图的总体结构和大小。你不必在开始训练之前编码所有可能的路径——你运行的内容就是你要进行微分的部分。
保存的张量
某些操作需要在前向传播过程中保存中间结果,以便执行反向传播。例如,函数 (x \mapsto x^2) 会保存输入 (x) 以计算梯度。
在定义自定义 Python Function
时,可以使用 save_for_backward()
在前向传播过程中保存张量,并使用 saved_tensors
在反向传播过程中检索它们。更多信息请参见 扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()
),张量会根据需要自动保存。你可以通过查找某个 grad_fn
的以 _saved
开头的属性来了解(出于教育或调试目的)哪些张量被保存。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在上述代码中,y.grad_fn._saved_self
引用的是与 x
相同的张量对象。但情况并非总是如此。例如:
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在内部,为了防止引用循环,PyTorch 在保存时会将张量 打包,并在读取时将其 解包 为另一个张量对象。在这里,从 y.grad_fn._saved_result
访问到的张量是一个与 y
不同的张量对象(但它们仍然共享相同的存储)。
一个张量是否会打包成不同的张量对象取决于它是否是其自身 grad_fn
的输出。这是一个实现细节,可能会发生变化,因此用户不应依赖于此。
您可以使用 保存张量的钩子 来控制 PyTorch 如何进行打包和解包。
非可微函数的梯度
使用自动求导计算梯度仅在每个基本函数都是可微的情况下有效。不幸的是,我们在实际中使用的许多函数并不具备这一性质(例如 relu
或 sqrt
在 0
处)。为了尽量减少非可微函数的影响,我们通过按顺序应用以下规则来定义基本操作的梯度:
-
如果函数可微并且在当前点存在梯度,则使用它。
-
如果函数是凸的(至少局部是凸的),则使用最小范数的次梯度(这是最陡峭的下降方向)。
-
如果函数是凹的(至少局部是凹的),则使用最小范数的超梯度(考虑 -f(x) 并应用前一条)。
-
如果函数已定义,在当前点通过连续性定义梯度(注意这里可以是
inf
,例如对于sqrt(0)
)。如果有多个可能的值,任意选择其中一个。 -
如果函数未定义(例如
sqrt(-1)
、log(-1)
或输入为NaN
的大多数函数),则用于梯度的值是任意的(我们可能会抛出错误,但不保证这样做)。大多数函数会将NaN
作为梯度,但由于性能原因,某些函数会使用其他值(例如log(-1)
)。 -
如果函数不是一个确定性的映射(即它不是一个[数学函数](https://en.wikipedia.org/wiki/Function_(mathematics))),则会被标记为不可微。这将在反向传播时导致错误,除非在
no_grad
环境中使用需要梯度的张量。
局部禁用梯度计算
Python 提供了几种机制来局部禁用梯度计算:
要禁用整个代码块的梯度计算,可以使用无梯度模式(no-grad mode)和推理模式(inference mode)等上下文管理器。对于更细粒度地排除子图的梯度计算,可以通过设置张量的 requires_grad
字段来实现。
除了上述机制外,我们还将介绍评估模式(nn.Module.eval()
)。尽管评估模式不是用来禁用梯度计算的,但由于其名称,常常与上述方法混淆。
设置 requires_grad
requires_grad
是一个标志,默认情况下为 False
,除非该张量被包裹在 nn.Parameter
中。它允许细粒度地控制子图的梯度计算,并在前向传递和后向传递中都生效:
在前向传递中,只有当至少有一个输入张量的 requires_grad
为 True
时,操作才会被记录在反向图中。在后向传递(.backward()
)中,只有当叶子张量的 requires_grad
为 True
时,其 .grad
字段才会累积梯度。
需要注意的是,尽管每个张量都有这个标志,但设置这一标志仅对叶张量(没有 grad_fn
的张量,例如 nn.Module
的参数)有意义。非叶张量(具有 grad_fn
的张量)是与反向传播图相关的张量。因此,它们的梯度将作为中间结果用于计算叶张量的梯度。根据这一定义,所有非叶张量都会自动设置 require_grad=True
。
设置 requires_grad
是你控制模型哪些部分参与梯度计算的主要方式。例如,在微调预训练模型时,你可以冻结模型的部分。
要冻结模型的部分,只需将不想更新的参数的 requires_grad
属性设置为 False
。如上所述,由于使用这些参数作为输入的计算不会在前向传播中记录,因此它们在反向传播中不会更新其 .grad
字段,因为它们不会成为反向图的一部分,这正是我们希望的结果。
因为这是一种非常常见的模式,requires_grad
也可以在模块级别使用 nn.Module.requires_grad_()
进行设置。当应用于模块时,.requires_grad_()
会对模块的所有参数生效(这些参数默认情况下 requires_grad=True
)。
梯度模式
除了设置 requires_grad
之外,还有三种梯度模式可以从 Python 中选择,这些模式会影响 PyTorch 内部 autograd 如何处理计算:默认模式(梯度模式)、无梯度模式和推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。
| 模式 | 不记录在反向图中的操作 | 跳过额外的自动梯度跟踪开销 | 在该模式下创建的张量可以在稍后的梯度模式中使用 | 示例 | | --- | --- | --- | --- | --- | | 默认模式 | | | ✓ | 前向传播(forward pass) | | 无梯度模式 | ✓ | | ✓ | 优化器更新(optimizer updates) | | 推理模式 | ✓ | ✓ | | 数据处理(data processing)、模型评估(model evaluation) |
默认模式(梯度模式)
“默认模式”是在没有启用其他模式(如无梯度模式和推理模式)时隐式进入的模式。与“无梯度模式”相对比,默认模式有时也被称为“梯度模式”。
关于默认模式最重要的一点是,它是唯一让 requires_grad
起作用的模式。在其他两种模式中,requires_grad
始终被覆盖为 False
。
无梯度模式
在无梯度模式下,计算会像所有输入都不需要梯度一样进行。换句话说,即使有输入的 require_grad=True
,无梯度模式下的计算也不会记录在反向传播图中。
当您需要执行不应被自动求导记录的操作,但又希望在稍后的梯度模式中使用这些计算的输出时,启用无梯度模式。这个上下文管理器使得在代码块或函数中禁用梯度变得方便,无需临时将张量的 requires_grad
设置为 False
再改回 True
。
例如,在编写优化器时,无梯度模式非常有用。在执行训练更新时,您希望在不被自动求导记录的情况下就地更新参数,并且打算在下一个前向传递中使用这些更新后的参数进行梯度模式下的计算。
torch.nn.init 中的实现也依赖于无梯度模式,在初始化参数时避免自动求导跟踪,因为这些参数是直接在原位置更新的。
推理模式
推理模式是无梯度模式的一种更极端的形式。与无梯度模式类似,推理模式下的计算不会被记录在反向图中,但启用推理模式可以让 PyTorch 进一步加速模型。这种性能提升有一个缺点:在推理模式下创建的张量在退出推理模式后无法用于自动求导记录的计算中。
当您执行的计算不涉及自动求导,并且不打算在退出推理模式后使用在推理模式下创建的张量进行任何需要自动求导记录的计算时,可以启用推理模式。
建议您在代码中不需要自动求导跟踪的部分(例如数据处理和模型评估)尝试启用推理模式。如果它能直接适用于您的用例,那么这是一个免费的性能提升。如果您在启用推理模式后遇到错误,请检查是否在退出推理模式后使用了在推理模式下创建的张量。这些张量不能用于自动求导记录的计算。如果无法避免这种情况,您可以随时切换回无梯度模式。
有关推理模式的详细信息,请参阅 推理模式。
有关推理模式的实现细节,请参阅 RFC-0011-InferenceMode。
评估模式 (nn.Module.eval()
)
评估模式并不是一种局部禁用梯度计算的机制。尽管如此,这里还是包括了它,因为有时人们会误认为它是一种这样的机制。
功能上,module.eval()
(或等效的 module.train(False)
)与无梯度计算模式和推理模式完全独立。model.eval()
如何影响你的模型完全取决于你在模型中使用的具体模块以及它们是否定义了特定于训练模式的行为。
如果你的模型依赖于像 torch.nn.Dropout
和 torch.nn.BatchNorm2d
这样的模块,这些模块在不同训练模式下可能会有不同的行为,例如避免在验证数据上更新 BatchNorm 的累积统计信息,那么你需要根据需要调用 model.eval()
和 model.train()
。
建议在训练时始终使用 model.train()
,在评估模型(验证/测试)时使用 model.eval()
。即使你不确定模型是否有特定于训练模式的行为,也应如此操作,因为使用的模块可能会在未来更新,导致在训练和评估模式下的行为不同。
自动求导中的原地操作
在自动求导中支持原地操作是一个复杂的问题,我们通常不建议在大多数情况下使用它们。由于自动求导的激进缓冲区释放和重用机制,它已经非常高效,因此很少有情况会因为使用原地操作而显著减少内存使用。除非你处于严重的内存压力下,否则你可能永远不会需要使用它们。
限制原地操作适用性的两个主要原因:
-
原地操作可能会覆盖计算梯度所需的数据。
-
每个原地操作都需要重新构建计算图。非原地版本只需分配新对象并保留对旧图的引用,而原地操作则需要更改所有输入的创建者,使其指向表示此操作的
Function
。这可能会变得复杂,特别是当许多张量引用同一存储(例如通过索引或转置创建)时。如果修改输入的存储被其他任何Tensor
引用,原地函数将抛出错误。
原地正确性检查
每个张量都维护一个版本计数器,每次在任何操作中被修改时都会递增。当一个函数保存任何用于反向传播的张量时,也会保存这些张量的版本计数器。一旦访问 self.saved_tensors
,系统会检查版本计数器。如果当前版本计数器大于保存的值,就会抛出错误。这确保了如果您使用原地操作并且没有看到任何错误,您可以确信计算的梯度是正确的。
多线程自动求导
自动求导引擎负责运行所有必要的反向操作以计算反向传播。本节将描述所有有助于您在多线程环境中充分利用它的细节。(这仅适用于 PyTorch 1.6+,因为之前的版本行为不同。)
用户可以使用多线程代码(例如 Hogwild 训练方法)训练他们的模型,并且不会阻塞并发的反向计算。示例如下:
# 定义一个训练函数,用于不同线程
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# 前向计算
y = (x + 3) * (x + 4) * 0.5
# 反向计算
y.sum().backward()
# 优化器更新
# 用户编写线程代码来驱动 train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
# 等待所有线程完成
for p in threads:
p.join()
请注意用户应了解的一些行为:
CPU 并发
当您在多个线程中通过 Python 或 C++ API 调用 backward()
或 grad()
时,您期望看到更多的并发性,而不是在执行过程中按特定顺序串行处理所有的反向调用(这是 PyTorch 1.6 之前的默认行为)。
非确定性
如果您从多个线程并发调用 backward()
并且有共享输入(例如 Hogwild CPU 训练),则应预期非确定性。这可能会发生,因为参数在各线程之间自动共享,因此多个线程可能会访问并尝试在梯度累积期间累加同一个 .grad
属性。这种操作在技术上是不安全的,可能会导致竞争条件,从而使结果无效。
开发具有共享参数的多线程模型的用户应考虑线程模型,并理解上述问题。
可以使用函数式 API torch.autograd.grad()
来计算梯度,以避免非确定性。
图保留
如果自动求导图的部分在多个线程间共享,例如,首先在一个线程中运行前向计算的第一部分,然后在多个线程中运行第二部分,那么第一部分的图是共享的。在这种情况下,不同的线程在同一个图上执行 grad()
或 backward()
可能会导致一个线程在执行过程中销毁图,从而导致其他线程崩溃。自动求导会像调用 backward()
两次但未设置 retain_graph=True
时那样报错,并告知用户应使用 retain_graph=True
。
自动求导节点的线程安全性
由于自动求导允许调用线程驱动其反向执行以实现潜在的并行性,确保在 CPU 上使用并行 backward()
调用时,共享部分或全部 GraphTask 的线程安全性非常重要。
自定义的 Python autograd.Function
由于 GIL 而自动线程安全。对于内置的 C++ 自动求导节点(如 AccumulateGrad、CopySlices)和自定义的 autograd::Function
,自动求导引擎使用线程互斥锁来确保可能有状态读写操作的自动求导节点的线程安全性。
C++ 钩子没有线程安全性
Autograd 不保证 C++ 钩子的线程安全性。如果希望钩子在多线程环境中正常工作,您需要编写适当的线程同步代码以确保线程安全。
复数的 Autograd
简短版本:
-
当您使用 PyTorch 对复数域和/或复数值的函数 $f(z)$ 进行微分时,梯度是在假设该函数是更大实值损失函数 $g(input)=L$ 的一部分的情况下计算的。计算出的梯度是 $\frac{\partial L}{\partial z^*}$(注意 z 的共轭形式),其负值正是梯度下降算法中使用的最陡下降方向。因此,现有的优化器可以直接用于复数参数。
-
这个约定与 TensorFlow 的复数微分约定一致,但与 JAX 计算 $\frac{\partial L}{\partial z}$ 的方式不同。
-
如果您有一个从实数到实数的函数,即使内部使用了复数运算,这里的约定也不影响结果:您总是会得到与仅使用实数运算实现时相同的结果。
如果你对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?
复数可微性的数学定义将导数的极限定义推广以适用于复数。考虑一个函数 $f: ℂ → ℂ$,
$f(z=x+yj) = u(x, y) + v(x, y)j$
其中 $u$ 和 $v$ 是两个变量的实值函数,$j$ 是虚数单位。
使用导数的定义,我们可以写成:
$f'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}$
为了使这个极限存在,不仅 $u$ 和 $v$ 必须是实可微的,而且 $f$ 还必须满足柯西-黎曼方程。换句话说:对于实部和虚部的步长($h$)计算出的极限必须相等。这是一个更严格的条件。
复数可微函数通常被称为全纯函数。它们性质良好,具有你在实可微函数中见过的所有良好性质,但在优化领域中实际应用较少。对于优化问题,研究社区主要使用实值目标函数,因为复数不属于任何有序域,因此复数值损失在实际应用中意义不大。
事实证明,没有重要的实值目标函数满足柯西-黎曼方程。因此,复解析函数的理论无法用于优化,大多数人因此使用维廷格微积分。
维廷格微积分登场……
所以,我们有这个伟大的复可微性和复解析函数理论,但完全无法使用,因为许多常用的函数都不是复解析的。一个数学家该怎么办呢?好吧,维廷格观察到,即使 $f(z)$ 不是复解析的,也可以将其重写为两个变量的函数 $f(z, z^)$,这样可以使它成为复解析的。这是因为 $z$ 的实部和虚部可以用 $z$ 和 $z^$ 表示为:
$\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^}{2} \\ \mathrm{Im}(z) &= \frac {z - z^}{2j} \end{aligned}$
维廷格微积分建议研究 $f(z, z^)$ 而不是 $f(z)$。如果 $f$ 是实可微的,那么 $f(z, z^)$ 一定是复解析的。另一种思考方式是将其视为从 $f(x, y)$ 到 $f(z, z^)$ 的坐标变换。这个函数有偏导数 $\frac{\partial }{\partial z}$ 和 $\frac{\partial}{\partial z^}$。我们可以使用链式法则来建立这些偏导数与 $z$ 的实部和虚部的偏导数之间的关系。
从上述方程中,我们得到:
$\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - i * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + i * \frac{\partial }{\partial y}\right) \end{aligned}$
这正是经典的 Wirtinger 微积分定义,你可以在 Wikipedia 上找到。
这一变化带来了许多美妙的结果。
-
首先,Cauchy-Riemann 方程可以简化为 $\frac{\partial f}{\partial z^*} = 0$(即,函数 $f$ 可以完全用 $z$ 来表示,而不需要涉及 $z^*$)。
-
另一个重要(且有些反直觉)的结果是,当我们对实值损失进行优化时,更新变量时应采取的步骤由 $\frac{\partial Loss}{\partial z^*}$ 给出(而不是 $\frac{\partial Loss}{\partial z}$)。
如需进一步阅读,请参阅:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger 微积分在优化中的应用?
音频和其他领域的研究人员通常使用梯度下降来优化具有复数变量的实值损失函数。通常,他们会将实部和虚部分别视为可以单独更新的通道。对于步长 $\alpha/2$ 和损失 $L$,我们可以在 $ℝ^2$ 中写出以下方程:
$\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}$
这些方程如何转换到复数空间 $ℂ$ 中?
$\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}$
这里有一个重要的观察:Wirtinger 微积分告诉我们,我们可以简化上述复数变量更新公式,使其仅引用共轭 Wirtinger 导数 $\frac{\partial L}{\partial z^*}$,这正是我们在优化中采取的步骤。
由于共轭 Wirtinger 导数为我们提供了真实值损失函数的确切步长,因此当您对具有真实值损失的函数进行微分时,PyTorch 会为您提供这个导数。
PyTorch 如何计算共轭 Wirtinger 导数?
通常,我们的导数公式将 grad_output 作为输入,表示我们已经计算的传入向量-雅可比积,即 $\frac{\partial L}{\partial s^}$,其中 $L$ 是整个计算的真实损失,而 $s$ 是我们函数的输出。这里的目标是计算 $\frac{\partial L}{\partial z^}$,其中 $z$ 是函数的输入。在实际损失的情况下,我们只需要计算 $\frac{\partial L}{\partial s^*}$,尽管链式法则表明我们还需要访问 $\frac{\partial L}{\partial s}$。如果您想了解详细的推导过程,请查看本节的最后一方程,然后继续阅读下一节。
让我们继续研究定义为 $f(z) = f(x+yj) = u(x, y) + v(x, y)j$ 的复函数 $f: ℂ → ℂ$。如上所述,autograd 的梯度约定主要用于实值损失函数的优化,因此假设 $f$ 是一个更大实值损失函数 $g$ 的一部分。使用链式法则,我们可以写成:
$\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}$
现在利用 Wirtinger 导数的定义,我们可以写出:
$\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned}$
需要注意的是,由于 $u$ 和 $v$ 是实函数,并且根据假设 $f$ 是一个实值函数的一部分,$L$ 也是实数,因此我们有:
$\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}$
即,$\frac{\partial L}{\partial s}$ 等于 $grad_output^*$。
解上述方程以求 $\frac{\partial L}{\partial u}$ 和 $\frac{\partial L}{\partial v}$,我们得到:
$\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}$
$\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}$
使用 (2),我们得到:
$\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (输出梯度)^* * \frac{\partial s}{\partial z^*} + 输出梯度 * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}$
最后一个方程对于编写自己的梯度公式非常重要,因为它将导数公式分解为一个更简单的形式,易于手动计算。
如何为复函数编写自己的导数公式?
上述方框中的方程给出了所有复函数导数的一般公式。然而,我们仍然需要计算 $\frac{\partial s}{\partial z}$ 和 $\frac{\partial s}{\partial z^*}$。有两种方法可以做到这一点:
第一种方法是直接使用 Wirtinger 导数的定义,并通过 $\frac{\partial s}{\partial x}$ 和 $\frac{\partial s}{\partial y}$(你可以用常规方法计算)来计算 $\frac{\partial s}{\partial z}$ 和 $\frac{\partial s}{\partial z^*}$。
第二种方法是使用变量替换技巧,将 $f(z)$ 重写为两个变量的函数 $f(z, z^*)$,并通过将 $z$ 和 $z^*$ 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是解析的,则只会使用 $z$(并且 $\frac{\partial s}{\partial z^*}$ 将为零)。
让我们以函数 $f(z = x + yj) = c * z = c * (x+yj)$ 为例,其中 $c \in ℝ$。
使用第一种方法计算 Wirtinger 导数,我们得到:
$\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}$
根据公式 (4),并且 grad_output = 1.0
(这是在 PyTorch 中调用 backward()
时默认的梯度输出值,用于标量输出),我们得到:
$\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c$
使用第二种方法计算 Wirtinger 导数,我们可以直接得到:
$\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}$
再次使用 (4),我们得到 $\frac{\partial L}{\partial z^*} = c$。可以看出,第二种方法涉及较少的计算,并且在快速计算时更为方便。
如何处理跨域函数?
有些函数将复数输入映射到实数输出,或反之亦然。这些函数是 (4) 的特殊情况,可以通过链式法则推导得出:
- 对于 $f: ℂ → ℝ$,我们得到:
$\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}$
- 对于 $f: ℝ → ℂ$,我们得到:
$\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})$
保存张量的钩子
你可以通过定义一对 pack_hook
和 unpack_hook
钩子来控制 如何打包/解包保存的张量。pack_hook
函数接受一个张量作为参数,但可以返回任何 Python 对象,如另一个张量、元组或包含文件名的字符串。unpack_hook
函数接收 pack_hook
的输出作为参数,并返回一个用于反向传播的张量。unpack_hook
返回的张量内容只需与传递给 pack_hook
的张量相同即可。特别是,与自动梯度相关的元数据可以忽略,因为它们会在解包过程中被覆盖。
一个这样的配对示例如下:
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意,unpack_hook
不应删除临时文件,因为它可能会被多次调用:临时文件应与返回的 SelfDeletingTempFile
对象具有相同的生命周期。在上述示例中,我们在不再需要临时文件时(即在删除 SelfDeletingTempFile
对象时)将其关闭,以防止文件泄漏。
注意
我们保证 pack_hook
只会被调用一次,但 unpack_hook
可能会被反向传播过程多次调用,并且我们期望它每次返回相同的数据。
警告
对任何函数的输入执行原地操作是禁止的,因为这可能导致意外的副作用。如果对 pack_hook
的输入进行原地修改,PyTorch 会抛出错误,但不会检测到对 unpack_hook
输入进行原地修改的情况。
为保存的张量注册钩子
你可以通过调用 SavedTensor
对象上的 register_hooks()
方法来为保存的张量注册一对钩子。这些对象作为 grad_fn
的属性公开,并以 _raw_saved_
前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
pack_hook
方法在注册这对钩子时立即调用。unpack_hook
方法在每次需要访问保存的张量时被调用,无论是通过 y.grad_fn._saved_self
还是在反向传播过程中。
警告
如果你在保存的张量被释放后(即在反向传播调用后)仍然保留对 SavedTensor
的引用,调用其 register_hooks()
是禁止的。PyTorch 在大多数情况下会抛出错误,但在某些情况下可能会失败并导致未定义行为。
为所有保存的张量注册默认钩子
或者,你可以使用上下文管理器 saved_tensors_hooks
来注册一对钩子,这些钩子将应用于该上下文中创建的所有保存的张量。
示例:
# 仅将元素数量 >= 1000 的张量保存到磁盘
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(x, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... 计算输出
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程局部的,即每个线程都有自己的钩子实例。因此,以下代码不会产生预期的效果,因为这些钩子不会在 DataParallel 中生效。
# 示例:不要这样做
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input) # input 是输入张量
请注意,使用这些钩子会禁用现有的减少 Tensor 对象创建的优化。例如:
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
如果没有这些钩子,x
、y.grad_fn._saved_self
和 y.grad_fn._saved_other
都引用同一个 Tensor 对象。使用这些钩子后,PyTorch 会创建两个新的 Tensor 对象,这两个对象与原始的 x
共享相同的存储(不执行复制)。
反向传播钩子的执行
本节将探讨不同钩子的触发条件及其触发顺序。将要讨论的钩子包括:
- 通过
torch.Tensor.register_hook()
注册到 Tensor 的反向传播钩子; - 通过
torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的后累积梯度钩子; - 通过
torch.autograd.graph.Node.register_hook()
注册到节点的后钩子; - 通过
torch.autograd.graph.Node.register_prehook()
注册到节点的前钩子。
特定的钩子是否会触发
通过 torch.Tensor.register_hook()
注册到 Tensor 的钩子会在计算该 Tensor 的梯度时被触发。(需要注意,这并不需要 Tensor 的 grad_fn 被执行。例如,如果 Tensor 作为 inputs
参数的一部分传递给 torch.autograd.grad()
,即使 Tensor 的 grad_fn 不被执行,注册到该 Tensor 的钩子仍然会被触发。)
注册到 torch.autograd.graph.Node
的钩子函数(使用 torch.autograd.graph.Node.register_hook()
或 torch.autograd.graph.Node.register_prehook()
)只有在其注册的节点被运行时才会触发。
特定节点是否会被执行可能取决于反向传播是通过调用 torch.autograd.grad()
还是 torch.autograd.backward()
来进行的。具体来说,当您在 inputs
参数中传递给 torch.autograd.grad()
或 torch.autograd.backward()
的张量对应的节点上注册钩子时,应了解这些差异。
如果你使用的是 torch.autograd.backward()
,上述提到的所有钩子都会被执行,无论你是否指定了 inputs
参数。这是因为 .backward()
会执行所有节点,即使这些节点对应于作为输入指定的张量。(请注意,执行这些额外的节点通常是不必要的,但仍然会被执行。这种行为可能会发生变化;你不应依赖于此。)
另一方面,如果你使用的是 torch.autograd.grad()
,注册到与传递给 input
的张量对应的节点的反向钩子可能不会被执行,因为除非有其他输入依赖于该节点的梯度结果,否则这些节点不会被计算。
不同钩子的触发顺序
事情发生的顺序如下:
-
注册到张量的钩子被触发
-
如果节点被触发,则注册到节点的预钩子被触发
-
对于保留梯度的张量,更新其
.grad
值 -
节点被触发(受上述规则约束)
-
对于累积了
.grad
的叶张量,执行后累积梯度的钩子 -
如果节点被触发,则注册到节点的后钩子被触发
如果在同一个 Tensor 或 Node 上注册了多个相同类型的 hook,则它们会按照注册的顺序执行。后执行的 hook 可以观察到先前的 hook 对梯度所做的修改。
特殊 hook
torch.autograd.graph.register_multi_grad_hook()
是通过注册到 Tensor 的 hook 实现的。每个单独的 Tensor hook 按照上述定义的顺序触发,当最后一个 Tensor 梯度计算完成时,注册的多梯度 hook 会被调用。
torch.nn.modules.module.register_module_full_backward_hook()
是通过注册到 Node 的 hook 实现的。在前向计算过程中,hook 会被注册到与模块输入和输出对应的 grad_fn。因为模块可能有多个输入和输出,所以在前向计算之前,会在模块的输入上应用一个虚拟的自定义 autograd 函数,并在前向计算返回之前在模块的输出上应用该函数,以确保这些 Tensor 共享一个单一的 grad_fn,从而可以附加我们的 hook。
当 Tensor 被就地修改时 Tensor 钩子的行为
通常情况下,注册到 Tensor 的钩子会接收相对于该 Tensor 的输出梯度,其中 Tensor 的值是在反向传播计算时的值。
然而,如果你在一个 Tensor 上注册了钩子,然后对该 Tensor 进行就地修改,那么在就地修改之前注册的钩子同样会接收到相对于该 Tensor 的输出梯度,但 Tensor 的值会被视为就地修改之前的值。
如果你希望实现前一种行为,应该在所有就地修改之后再注册这些钩子。例如:
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解以下内部机制也很有帮助:当钩子注册到一个 Tensor 时,它们会永久绑定到该 Tensor 的 grad_fn。因此,即使该 Tensor 被就地修改并有了新的 grad_fn,但在就地修改之前注册的钩子仍然会与旧的 grad_fn 关联。例如,当自动求导引擎在计算图中到达该 Tensor 的旧 grad_fn 时,这些钩子会被触发。