扩展 PyTorch
在本文中,我们将介绍扩展torch.nn
、torch.autograd
、torch
,以及编写自定义C++扩展的方法。
添加新操作符
PyTorch 提供了一个庞大的操作符库,用于对张量进行操作(例如:torch.add()
,torch.sum()
等等)。然而,您可能希望将一个新的自定义操作引入 PyTorch,并使其行为与 PyTorch 的内置操作符类似。为了实现这一点,您必须通过 Python torch.library 或 C++ TORCH_LIBRARY API 将自定义操作注册到 PyTorch 中。
请参阅PyTorch自定义操作符 landing page以获取更多详细信息。
扩展torch.autograd
要向autograd
中添加操作,需要为每个操作实现一个新的Function
子类。记住,函数是autograd
用来编码操作历史和计算梯度的工具。
本文的第一部分重点介绍了反向模式自动微分,因为这是最广泛使用的功能。最后一节讨论了正向模式自动微分的扩展。
使用场景
一般来说,如果你希望在模型中执行不可微分的计算或依赖于非 PyTorch 库(例如 NumPy),但仍然希望你的操作能够与其它运算符串联并支持自动求梯度引擎,则应实现一个自定义函数。
在某些情况下,自定义函数也可以用于提升性能和内存使用效率:如果您使用C++扩展实现了前向和反向传递,则可以将它们包装在Function
中以与自动微分引擎接口。如果您希望减少为反向传递保存的缓冲区数量,自定义函数可用于将操作合并在一起。
不使用的时机
如果你已经能够用PyTorch的内置操作符来编写你的函数,那么它的反向计算图(很可能会)已经被autograd记录下来了。这种情况下,你不需要自己实现反向函数。考虑使用一个普通的Python函数。
如果你需要维护状态,即可训练参数,你应该(也)使用一个自定义模块。有关扩展torch.nn
的更多信息,请参阅下文。
如果您希望在反向传递期间修改梯度或执行副作用,请考虑注册一个tensor或Module钩子。
如何使用
按照以下步骤操作:1. 继承Function
并实现forward()
,(可选)setup_context()
和backward()
方法。2. 在ctx参数上调用适当的函数。3. 声明您的函数是否支持双倍反向传播。4. 使用gradcheck验证梯度是否正确。
步骤 1: 继承Function
后,你需要定义 3 个方法:
-
forward()
是执行操作的代码。它可以接受任意数量的参数,其中一些参数可以是可选的(如果你指定了默认值)。这里接受各种 Python 对象。带有历史记录跟踪(即requires_grad=True
)的Tensor
参数将在调用之前转换为不带历史记录跟踪的 Tensor,并且它们的使用将被注册到计算图中。请注意,此逻辑不会遍历列表/字典/任何其他数据结构,只会考虑直接作为调用参数传递的张量。你可以返回单个Tensor
输出,或者如果输出多个,则返回一个tuple
的张量。此外,请参考Function
的文档,以找到只能从forward()
调用的有用方法的描述。 -
setup_context()
(可选)。你可以编写一个接受ctx
对象的“组合”forward()
,或者(从 PyTorch 2.0 开始)一个不接受ctx
的单独forward()
和一个setup_context()
方法,其中进行ctx
修改。forward()
应该包含计算,而setup_context()
只负责修改ctx
(而不执行任何计算)。通常,分开的forward()
和setup_context()
更接近 PyTorch 原生操作的工作方式,因此与各种 PyTorch 子系统更兼容。有关更多详细信息,请参见 Combined or separate forward() and setup_context()。 -
backward()
(或vjp()
)定义了梯度公式。它将获得与输出相同数量的Tensor
参数,每个参数表示对相应输出的梯度。重要的是,绝 NEVER 对这些张量进行原地修改。它应该返回与输入相同数量的张量,每个张量包含对应输入的梯度。如果你的输入不需要梯度(needs_input_grad
是一个布尔元组,指示每个输入是否需要计算梯度),或者它们不是Tensor
对象,则可以返回python:None
。此外,如果你在forward()
中有可选参数,你可以返回比输入更多的梯度,只要它们都是None
。
步骤 2: 您有责任正确使用 ctx
中的函数,以确保新的 Function
能够与自动求导引擎正常配合工作。
-
必须使用
save_for_backward()
来保存任何用于反向传播的张量。非张量应直接存储在ctx中。如果保存了既不是输入也不是输出的张量用于反向传播,那么您的Function
可能不支持双层反向传播(参见步骤3)。 -
必须使用
mark_dirty()
来标记任何在前向函数中被原地修改的输入。 -
必须使用
mark_non_differentiable()
来告知引擎某个输出是否不可微分。默认情况下,所有可微分类型的输出张量都会被设置为需要计算梯度。对于不可微分类型(即整数类型)的张量,永远不会标记为需要梯度。 -
可以使用
set_materialize_grads()
来告知自动微分引擎在输出不依赖于输入的情况下优化梯度计算,即不会将传递给反向函数的梯度张量显式地转换为零张量。具体来说,如果设置为False,则Python中的None对象或C++中的“未定义张量”(即x.defined()为False的张量)在调用backward之前不会被转换为全零张量,因此您的代码需要将这些对象视为全零张量进行处理。默认值为True。
步骤 3:如果你的Function
不支持双倍反向传播,你应该通过用once_differentiable()
装饰器包装backward方法来显式地声明这一点。使用这个装饰器后,尝试通过对你的函数进行双倍反向传播将导致错误。有关双倍反向传播的更多信息,请参阅我们的双倍反向传播教程。
步骤 4: 建议使用 torch.autograd.gradcheck()
来验证您的反向函数是否正确计算前向传播的梯度。该方法通过利用您的反向函数计算雅可比矩阵,并将其与基于有限差分法计算得到的雅可比矩阵进行逐元素比较,从而实现这一目的。
示例
以下你可以找到带有额外注释的Linear
函数代码:
# Inherit from Function
class LinearFunction(Function):
# Note that forward, setup_context, and backward are @staticmethods
@staticmethod
def forward(input, weight, bias):
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
# inputs is a Tuple of all of the inputs passed to forward.
# output is the output of the forward().
def setup_context(ctx, inputs, output):
input, weight, bias = inputs
ctx.save_for_backward(input, weight, bias)
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
现在,为了让使用这些自定义操作更加简便,我们建议您为它们创建别名或将它们封装在函数中。将操作封装在函数中,使我们能够支持默认参数和关键字参数。
# Option 1: alias
linear = LinearFunction.apply
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
return LinearFunction.apply(input, weight, bias)
这里,我们提供一个额外的例子,说明一个由非张量参数参数化的函数。
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
# ctx is a context object that can be used to stash information
# for backward computation
tensor, constant = inputs
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
在这里,我们通过调用set_materialize_grads(False)来优化上面的例子:
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
tensor, constant = inputs
ctx.set_materialize_grads(False)
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# Here we must handle None grad_output tensor. In this case we
# can skip unnecessary computations and just return None.
if grad_output is None:
return None, None
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
如果你需要任何在forward()
中计算的“中间”张量被保存,那么要么必须将它们作为输出返回,要么结合forward
和setup_context()
(参见Combined or separate forward() and setup_context())。注意这意味着如果你想让这些中间值的梯度能够流动,你需要为它们定义梯度公式(也参见双反向教程):
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
# We wish to save dx for backward. In order to do so, it must
# be returned as an output.
dx = 3 * x ** 2
result = x ** 3
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`,
# which is grad_dx * 6 * x.
result = grad_output * dx + grad_dx * 6 * x
return result
# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
result, dx = MyCube.apply(x)
return result
注
backward
的输入,即grad_output
,也可以是追踪历史记录的张量。因此,如果backward
的实现使用了可微分操作(例如调用另一个自定义的Function
),高阶导数将正常工作。在这种情况下,通过save_for_backward
保存的张量也可以在反向传播中使用,并且它们会有梯度流回。然而,存储在ctx
中的张量不会有梯度流回。如果你希望某个存储在ctx
中的张量有梯度流回,则应将其作为自定义Function
的输出,并通过save_for_backward
进行保存。
你可能想检查一下你实现的反向方法是否实际计算了你的函数的导数。可以通过使用小有限差分与数值近似进行比较来验证这一点:
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
有关有限差分梯度比较的更多详细信息,请参阅Numerical gradient checking。如果您的函数用于高阶导数(对反向传播进行求导),您可以使用相同包中的gradgradcheck
函数来检查高阶导数。
合并或独立的forward()
和setup_context()
有两种主要方法可以定义Function
。
我们推荐选择第二个选项(将forward()
和setup_context()
分开),因为这与PyTorch原生操作的实现方式更为接近,并且可以与torch.func
变换进行组合。不过,我们计划未来支持这两种方法;将forward()
与setup_context()
结合使用:能够带来更大的灵活性,因为你可以保存中间结果而无需将其作为输出返回。
请参阅上一节,了解如何定义具有单独的Function
和setup_context()
的方法。
这是一个定义一个Function
的示例,该函数结合了forward()
和setup_context()
:
class LinearFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(ctx, input, weight, bias=None):
# The forward pass can use ctx.
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
前向模式 AD
覆盖前向模式自动微分公式的方法具有非常相似的API,但有一些细微差别。你可以实现jvp()
函数。
它将根据输入的数量提供相同数量的张量
参数,每个参数表示相对于该输入的梯度。它应该返回与输出数量相同的张量,每个张量包含相对于其对应输出的梯度。在调用forward()
方法之后,立即调用jvp()
,但在apply()
返回之前。
jvp()
与backward()
函数有一些细微的差别:
-
您可以使用 ctx 传递任何数据从
forward()
到jvp()
函数。如果该状态不需要用于backward()
,您可以在jvp()
函数末尾通过del ctx.foo
显式释放它。 -
对于
jvp()
的实现,必须是可反向求导的,或者显式检查给定的前向模式梯度是否没有设置requires_grad
。 -
jvp()
函数必须与forward()
的视图/原地行为一致。例如,如果第i
个输入在原地被修改,则第i
个梯度必须在原地更新。类似地,如果第j
个输出是第k
个输入的视图。那么返回的第j
个输出梯度必须是给定的第k
个输入梯度的视图。 -
因为用户无法指定需要计算哪个梯度,所以
jvp()
函数应始终为所有输出计算梯度。 -
前向模式的梯度尊重由
set_materialize_grads()
设置的标志,并且在禁用此功能时,您可以获得为None
的输入梯度。
torch.func
转换和/或 torch.vmap()
请参见使用autograd.Function扩展torch.func以获取详细信息。
扩展torch.nn
(nn
](../nn.html#module-torch.nn) 暴露了两种类型的接口——模块及其功能版本。你可以通过这两种方式来扩展它,但我们建议在所有类型的层中使用模块(这些层包含任何参数或缓冲区),并推荐使用无参数的操作形式,如激活函数、池化等。
在上面的部分中,已经完全涵盖了添加操作的功能版本。
添加一个Module
由于nn
heavily 利用autograd
,添加一个新的Module
需要实现一个执行操作并能计算梯度的Function
。从现在开始,让我们假设我们希望实现一个Linear
模块,并且已经在上面的列表中实现了功能。添加这个只需要很少的代码。现在有两个函数需要实现:
-
__init__
(optional) - 接受诸如卷积核大小、特征数量等参数,并初始化模型的参数和缓存。
这是如何实现一个 Linear
模块的方法:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super().__init__()
self.input_features = input_features
self.output_features = output_features
# nn.Parameter is a special kind of Tensor, that will get
# automatically registered as Module's parameter once it's assigned
# as an attribute. Parameters and buffers need to be registered, or
# they won't appear in .parameters() (doesn't apply to buffers), and
# won't be converted when e.g. .cuda() is called. You can use
# .register_buffer() to register buffers.
# nn.Parameters require gradients by default.
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
# You should always register all possible parameters, but the
# optional ones can be None if you want.
self.register_parameter('bias', None)
# Not a very smart way to initialize weights
nn.init.uniform_(self.weight, -0.1, 0.1)
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1, 0.1)
def forward(self, input):
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
def extra_repr(self):
# (Optional)Set the extra information about this module. You can test
# it by printing an object of this class.
return 'input_features={}, output_features={}, bias={}'.format(
self.input_features, self.output_features, self.bias is not None
)
扩展torch
Python API
你可以通过定义一个自定义类并实现与 Tensor
相同的方法来创建自定义类型,以模拟 Tensor
。但是,如果你想让这些类型能够传递给像 torch.add()
这样的函数,这些函数位于顶级的 torch
命名空间中,并且接受 Tensor
操作数,该怎么办呢?
如果你的自定义 Python 类型定义了一个名为 __torch_function__
的方法,PyTorch 将在将你的自定义类实例传递给 torch
命名空间中的函数时调用你的 __torch_function__
实现。这样,你可以为 torch
命名空间中的任何函数定义自定义实现,而你的 __torch_function__
实现有权调用它们,从而让用户能够将其自定义类型与现有的 PyTorch 工作流程一起使用,这些工作流程已经为 Tensor
编写。这不仅适用于与 Tensor
无关的“鸭”类型,也适用于用户定义的 Tensor
子类。
扩展torch
以添加一个类似于Tensor
的类型
注
此功能的灵感来源于NumPy的__array_function__
协议。更多细节请参阅NumPy文档和NEP-0018。
为了具体说明这一点,让我们从一个简单的例子开始,展示API分发机制。我们将创建一个自定义类型,表示一个二维标量张量,由阶数N
和对角线元素的值value
参数化:
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
这个设计的第一版并没有太大的实用性。ScalarTensor
的主要功能是提供比基础张量类更简洁的标量张量字符串表示。
>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 0., 0.],
[0., 0., 0., 2., 0.],
[0., 0., 0., 0., 2.]])
如果尝试使用此对象与torch
API一起使用,我们将遇到问题:
>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
在 ScalarTensor
中添加 __torch_function__
实现,使得上述操作能够成功。让我们重新实现一下,这次添加一个 __torch_function__
方法:
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
__torch_function__
方法接受四个参数:func
,一个指向被重写的PyTorch API函数的引用;types
,实现__torch_function__
的张量类似类型的列表;args
,传递给函数的参数元组;以及kwargs
,传递给函数的关键字参数字典。它使用一个名为HANDLED_FUNCTIONS
的全局分发表来存储自定义实现。该字典的键是torch
命名空间中的函数,值是ScalarTensor
的实现。
注
使用全局分发表并不是 __torch_function__
API 的强制要求,它只是用于结构化您重写实现的一个有用的设计模式。
这个类定义还不足以让 torch.mean
在传递给它一个 ScalarTensor
时做出正确的行为——我们还需要为 ScalarTensor
操作数定义 torch.mean
的实现,并将其实现添加到 HANDLED_FUNCTIONS
分发表字典中。一种方法是定义一个装饰器:
import functools
def implements(torch_function):
"""Register a torch function override for ScalarTensor"""
def decorator(func):
functools.update_wrapper(func, torch_function)
HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
which可以应用于我们重写的实现中:
@implements(torch.mean)
def mean(input):
return float(input._value) / input._N
有了这个改动,我们现在可以使用 torch.mean
与 ScalarTensor
一起工作了:
>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4
当然,torch.mean
是一个最简单的函数示例,因为它只需要一个操作数。我们也可以使用相同的机制来覆盖一个接受多个操作数的函数,其中任何一个可能是定义了 __torch_function__
的张量或类似张量的对象,例如对于 torch.add()
:
def ensure_tensor(data):
if isinstance(data, ScalarTensor):
return data.tensor()
return torch.as_tensor(data)
@implements(torch.add)
def add(input, other):
try:
if input._N == other._N:
return ScalarTensor(input._N, input._value + other._value)
else:
raise ValueError("Shape mismatch!")
except AttributeError:
return torch.add(ensure_tensor(input), ensure_tensor(other))
这个版本为当两个操作数都是ScalarTensor
实例时提供了一个快速路径,同时也提供了一个较慢的路径,在其中一个操作数不是ScalarTensor
时会退化为将数据转换为张量。这样可以确保重写的函数在任一操作数是ScalarTensor
或普通Tensor
时都能正确运行。
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
[1., 3.]])
注意,我们的add
实现不接受alpha
或out
作为关键字参数,而像torch.add()
那样。
>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'
为了速度和灵活性,__torch_function__
分发机制不会检查重写函数的签名是否与torch
API中被重写的函数的签名匹配。对于某些应用程序,忽略可选参数可能是可以接受的,但为了确保与Tensor
完全兼容,用户实现的torch API函数应小心地精确模拟被重写函数的API。
在torch
API中,如果没有显式重载的函数将从__torch_function__
返回NotImplemented
。如果所有操作数在其上定义了__torch_function__
都返回NotImplemented
,PyTorch将引发一个TypeError
。这意味着大多数情况下,对于没有该类型显式重载的操作,当传递此类类型的实例时,会引发一个TypeError
。
>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]
在实际操作中,这意味着如果你想通过实现一个__torch_function__
来实现你的自定义功能,你需要显式地实现完整的torch
API或与你的用例相关的整个子集。这可能是一个艰巨的任务,因为完整的torch
API相当庞大。
另一个选择是,对于未处理的操作不返回NotImplemented
,而是当没有重写时,将一个Tensor
传递给原始的torch
函数。例如,如果我们修改了ScalarTensor
的__torch_function__
实现为以下内容:
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
return func(*args, **kwargs)
return HANDLED_FUNCTIONS[func](*args, **kwargs)
然后torch.mul()
将正确工作,尽管返回类型始终是Tensor
而不是ScalarTensor
,即使两个操作数都是ScalarTensor
实例。
>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
[0., 4.]])
同时请参见下面的MetadataTensor
示例,这是该模式的另一种变体,但它始终返回一个MetadataTensor
,以便在使用torch
API时通过操作传播元数据。
__torch_function__
协议旨在实现API的全面覆盖,部分覆盖可能导致意外的结果,特别是某些函数会引发TypeError
。对于子类来说,这一点尤为重要,因为必须同时覆盖torch.add、torch.Tensor.__add__和torch.Tensor.add这三个功能,即使它们返回完全相同的结果。未能做到这一点还可能引发无限递归。如果需要在torch.Tensor
的子类中实现某个函数,他们必须在其实现内部使用super().__torch_function__
。
继承 torch.Tensor
自版本 1.7.0 起,torch.Tensor
上的方法以及公共 torch.*
命名空间中应用于 torch.Tensor
子类的函数将返回子类实例,而不是 torch.Tensor
实例。
>>> class SubTensor(torch.Tensor):
... pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'
如果有多个子类存在,默认情况下会选择层级结构中最低的一个。如果无法唯一确定这种情况,则会引发一个 TypeError
错误。
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]
如果有人希望为所有张量方法提供一个全局覆盖,可以使用__torch_function__
。以下是一个记录所有函数/方法调用的示例:
class LoggingTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
if func is not torch.Tensor.__repr__:
logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
然而,如果有人希望重写张量子类中的方法,他们可以这样做:直接重写该方法(通过为子类定义它),或者使用__torch_function__
并匹配func
。
在使用__torch_function__
时,子类应始终调用super().__torch_function__(func, ...)
而不是直接调用func
,因为这是在1.7.0版本之前的做法。如果忽略这一点,可能会导致func
递归调用__torch_function__
,从而引发无限递归。
扩展torch
添加一个张量
包装类型
另一个有用的案例是一种类型,它包装一个张量
,无论是作为属性还是通过子类。下面我们实现这种类型的特殊案例,即一个元数据张量
,它将一个元数据字典附加到张量
上,并通过torch
操作传播该元数据。由于这是一种对整个torch
API进行全面包装的通用类型,因此我们不需要逐个实现每个重写方法,这样可以使__torch_function__
实现对允许的操作更加宽松。
class MetadataTensor(object):
def __init__(self, data, metadata=None, **kwargs):
self._t = torch.as_tensor(data, **kwargs)
self._metadata = metadata
def __repr__(self):
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
args = [getattr(a, '_t', a) for a in args]
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=metadatas[0])
这个简单的实现不一定能与torch
API中的每个函数配合使用,但它足以捕获大多数常见操作。
>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[2, 4],
[4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[1, 4],
[3, 8]])
定义 __torch_function__
的多类型操作
可以使用具有多个不同类型的 torch API,每个类型都有一个 __torch_function__
实现,但必须特别小心。在这种情况下,规则是:
-
调度操作会为每个操作数收集所有不同的
__torch_function__
实现,并按顺序调用它们:子类在先,父类在后,其余的则按照运算符表达式从左到右依次调用。 -
如果有任何实现返回除了
NotImplemented
以外的值,则该值将作为结果返回。实现可以注册不支持某个操作,方法是返回NotImplemented
。 -
如果所有
__torch_function__
实现都返回NotImplemented
,PyTorch将引发一个TypeError
异常。
PyTorch API 的覆盖测试
实现 __torch_function__
的一个麻烦方面是,如果某些操作有重写而其他操作没有,则用户最多会看到不一致的体验,或者在使用未提供重写的函数时在运行时遇到错误。为了简化这一过程,PyTorch 提供了一个面向开发者的 API,以确保对 __torch_function__
重写的全面支持。此 API 是私有的,并且未来可能会在没有事先通知的情况下更改。
首先,要获取所有可覆盖函数的列表,请使用 torch.overrides._get_overridable_functions
。这将返回一个字典,其键是 PyTorch
Python API 中的命名空间,值是在该命名空间中可以被覆盖的函数列表。例如,让我们打印出 torch.nn.functional
中前5个可被覆盖的函数名称:
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']
此函数列表使我们能够遍历所有可重写的函数,然而在实际操作中,仅凭这一点不足以编写针对这些函数的所有测试,因为需要繁琐且手动地为每个测试复制每个函数的签名。为了简化这一过程,torch.overrides._get_testing_overrides
函数返回一个字典,该字典将 PyTorch
API 中的可重写函数映射到具有与原函数相同签名但无条件返回 -1 的哑元 lambda 函数。这些函数主要用于配合 inspect
分析原始 PyTorch
函数的签名:
>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>
最后,torch.overrides.get_ignored_functions
返回一个元组,其中包含显式不能通过 __torch_function__
覆盖的函数。这个列表可以用来确认那些不在 get_overridable_functions
返回字典中的函数是无法被覆盖的。
扩展torch
原生API
虽然 __torch_function__
可以让用户有效地扩展 PyTorch 纯 Python 组件的行为,但它无法扩展 PyTorch 用 C++ 实现的部分。为此,Tensor
的子类还可以定义 __torch_dispatch__
,从而能够在 C++ 层面重写行为。
要有效利用此功能,了解PyTorch的本机实现方式至关重要。其中最重要的部分是所谓的“分发器”(关于它的最佳描述可见于这篇博客文章,尽管该内容略有过时)。正如其名称所暗示的那样,它负责为特定函数调用调用正确的后端函数。例如,当调用torch.add(a, b)
时,分发器将检查两个参数,确定应为此特定调用使用哪个“功能”(autograd、autocast、函数化等)和哪个“后端”(CPU、CUDA、MPS等),然后最终调用所有正确的内核。内核执行的一项非常常见的操作是“重新分发”。例如,在使用autocast在GPU上运行神经网络时,第一个调用将是autocast内核,它将处理任何可能的autocast逻辑并重新分发下去。接下来的功能将是autograd,它会正确创建autograd图,然后重新分发下去。最后,我们到达CUDA后端内核,它将启动正确的CUDA内核并返回最终结果。在返回过程中,autograd会将图形附加到输出上,而autocast则有机会在其退出时进行任何必要的更新。
分发器的配置之一是所有这些功能和后端键按什么顺序被调用。最新的列表及其顺序可以在 DispatchKey.h
中的 DispatchKey
枚举中找到。为了扩展 torch 的目的,与本讨论相关的排序的重要子集如下:
vmap -> 自动混合精度 -> 自动求导 -> 零张量 -> 负数/共轭 -> 功能化 -> Python -> 后端
在这个讨论中,最重要的关键点是 Python
,因为每个定义了 __torch_dispatch__
方法的张量子类都会调用此功能。用户定义的方法正是从那里被调用,并且行为可以在任意地方被重写。再次调用提供的 func
会执行“ redispatch ”。
这一实现的一些重要影响包括:
-
此代码在“所有功能之下”运行,因此它仅像常规后端一样负责生成每个张量的输出值(并且可以且应该忽略所有高级功能,如自动求导、自动混合精度等)。
-
如果任何高级功能实现了一个给定函数而没有重新分派,则该函数将永远不会到达“Python”键,因此
__torch_dispatch__
回调将永远不会被触发。这种情况特别适用于CompositeImplicitAutograd
函数,它们在Autograd级别通过分解为基本算子并直接评估这些算子而不是重新分派来实现。这是因为CompositeImplicitAutograd
函数通过隐式调用其他本机算子来指定其自动求导公式,因此在Autograd级别,该函数被分解为其本机算子,并且这些算子会被评估。 -
在回调Python并包装结果时,使用与常规PyTorch Python/C++绑定相同的转换。特别是,某些对象无法在Python中表示,需要特殊处理(例如未定义的张量会变成
None
)。 -
我们的本机函数作为
torch.ops.{namespace}.{func_name}.{overload_name}
延迟加载为可调用的Python对象,以实现与它们从Python交互的简便性。传递给__torch_dispatch__
的func
对象始终来自此命名空间。此命名空间可以直接用于调用本机算子并绕过常规的PyTorch Python API和绑定代码。
以类似的方式,__torch_function__
能够拦截 torch 的所有 Python API 和张量方法调用,而 __torch_dispatch__
则能够截获所有进入 aten 原生 API 的调用。注意,张量上的所有方法在进入调度器之前都会被转换为函数调用,因此它们将在这里显示为函数调用:torch.add(a, 2)
和 a + 2
将导致完全相同的 aten 调用。这些功能大多定义在 native_functions.yaml
中,该文件指定了这些函数的属性及其后端实现。然后通过代码生成自动注册它们以及指定的功能。一些更奇特的功能或特性也会在 C++ 代码库的其他地方或用户定义的 C++ 扩展中注册。
还可以使用torch.library
添加新的原生函数。这个Python功能允许定义和/或向原生函数添加新实现。这可以用于添加缺失的内核,替换现有的内核,或者定义全新的原生函数。
您可以在 subclass zoo 仓库中找到许多基于 __torch_dispatch__
的子类示例。
__torch_dispatch__
调用约定
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass
当用户调用带有__torch_dispatch__
输入的操作符时,该调用可能会转发到__torch_dispatch__
。在调用__torch_dispatch__
之前,args和kwargs会进行规范化,即:
-
kwargs
包含操作符架构中的仅关键字参数。如果某个关键字参数等于其在架构中的默认值,则不会传递该参数。 -
args
包含所有其他参数,无论它们是以位置还是关键字方式传递给操作符的。如果某个参数等于其默认值,并且它是最右边的位置参数,或者所有位于它右侧的参数均未传递,则该参数将不会被传递。
扩展所有torch
API与模式
很不幸,有些函数并不接受张量输入。这意味着上述描述的子类方法无法用于覆盖PyTorch所有函数的行为。此外,如果使用场景要求拦截每一个函数调用,则将每个张量改为子类可能会过于侵入性。
为了解决这个用例,我们引入了“模式”这一概念。这些模式适用于对__torch_function__
和__torch_dispatch__
的重写,并通过分别继承自torch.overrides.TorchFunctionMode
和torch.utils._python_dispatch.TorchDispatchMode
来创建,作为上下文管理器使用。
为了简化描述它如何与子类和其他模式交互,每当进入一个模式的上下文管理器时,每个函数都会表现得好像在参数列表开头有一个额外的张量参数,该参数以子类的形式传递了模式。这意味着特别地,所有模式处理程序将在任何子类处理程序之前被调用,并且与内部上下文管理器对应的模式始终首先运行。
需要注意的是,在特定模式处理器中,此模式会被禁用,但可以通过执行 with self:
手动重新启用。
这是一个显示每种类型日志模式的示例:
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
class FunctionLog(TorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
def f():
a = torch.rand(10, requires_grad=True)
b = a * 2
b.sum().backward()
print("TorchFunctionMode logging:")
with FunctionLog():
f()
print("TorchDispatchMode logging:")
with DispatchLog():
f()
Which prints the following, with extra comments:
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})
TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})