PyTorch 自定义运算符
PyTorch 提供了大量用于操作张量的运算符库(例如 torch.add
、torch.sum
等)。然而,您可能希望为 PyTorch 引入一个新的自定义操作,并使其能够与 torch.compile
、autograd 和 torch.vmap
等子系统协同工作。为此,您必须通过 Python 的 torch.library 文档 或 C++ 的 TORCH_LIBRARY
API 将该自定义操作注册到 PyTorch 中。
从 Python 编写自定义运算符
请参阅 自定义 Python 运算符。
如果满足以下情况,您可能希望使用 Python(而不是 C++)编写自定义运算符:
-
您有一个 Python 函数,希望 PyTorch 将其视为一个不透明的可调用对象,尤其是在
torch.compile
和torch.export
的上下文中。 -
您有一些与 C++/CUDA 内核绑定的 Python 代码,并希望这些代码能够与 PyTorch 子系统(如
torch.compile
或torch.autograd
)结合使用。 -
您正在使用 Python(而不是像 AOTInductor 这样的纯 C++ 环境)。
将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成
请参阅 自定义 C++ 和 CUDA 运算符。
如果您希望从 C++(而不是 Python)编写自定义运算符,可能是出于以下原因:
-
您有自定义的 C++ 和/或 CUDA 代码。
-
您计划将此代码与
AOTInductor
一起使用,以进行无 Python 的推理。
自定义操作符手册
有关教程和本页面中未涵盖的信息,请参阅 自定义操作符手册(我们正在努力将这些信息迁移到我们的文档站点)。我们建议您首先阅读上述教程之一,然后将自定义操作符手册作为参考;它并非从头到尾阅读的指南。
何时应该创建自定义操作符?
如果您的操作可以表示为内置 PyTorch 操作的组合,那么请将其编写为 Python 函数并调用它,而不是创建一个自定义操作。如果您正在调用 PyTorch 无法理解的某些库(例如,自定义 C/C++ 代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用操作注册 API 来创建自定义操作。
为什么要创建自定义操作符?
可以通过获取 Tensor 的数据指针并将其传递给使用 pybind 绑定的内核来使用 C/C++/CUDA 内核。然而,这种方法无法与 PyTorch 的子系统(如 autograd、torch.compile、vmap 等)协同工作。为了让一个操作能够与 PyTorch 子系统协同工作,必须通过操作符注册 API 进行注册。