PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

PyTorch 自定义运算符

PyTorch 提供了大量用于操作张量的运算符库(例如 torch.addtorch.sum 等)。然而,您可能希望为 PyTorch 引入一个新的自定义操作,并使其能够与 torch.compile、autograd 和 torch.vmap 等子系统协同工作。为此,您必须通过 Python 的 torch.library 文档 或 C++ 的 TORCH_LIBRARY API 将该自定义操作注册到 PyTorch 中。

从 Python 编写自定义运算符

请参阅 自定义 Python 运算符

如果满足以下情况,您可能希望使用 Python(而不是 C++)编写自定义运算符:

  • 您有一个 Python 函数,希望 PyTorch 将其视为一个不透明的可调用对象,尤其是在 torch.compiletorch.export 的上下文中。

  • 您有一些与 C++/CUDA 内核绑定的 Python 代码,并希望这些代码能够与 PyTorch 子系统(如 torch.compiletorch.autograd)结合使用。

  • 您正在使用 Python(而不是像 AOTInductor 这样的纯 C++ 环境)。

将自定义 C++ 和/或 CUDA 代码与 PyTorch 集成

请参阅 自定义 C++ 和 CUDA 运算符

如果您希望从 C++(而不是 Python)编写自定义运算符,可能是出于以下原因:

  • 您有自定义的 C++ 和/或 CUDA 代码。

  • 您计划将此代码与 AOTInductor 一起使用,以进行无 Python 的推理。

自定义操作符手册

有关教程和本页面中未涵盖的信息,请参阅 自定义操作符手册(我们正在努力将这些信息迁移到我们的文档站点)。我们建议您首先阅读上述教程之一,然后将自定义操作符手册作为参考;它并非从头到尾阅读的指南。

何时应该创建自定义操作符?

如果您的操作可以表示为内置 PyTorch 操作的组合,那么请将其编写为 Python 函数并调用它,而不是创建一个自定义操作。如果您正在调用 PyTorch 无法理解的某些库(例如,自定义 C/C++ 代码、自定义 CUDA 内核或 C/C++/CUDA 扩展的 Python 绑定),请使用操作注册 API 来创建自定义操作。

为什么要创建自定义操作符?

可以通过获取 Tensor 的数据指针并将其传递给使用 pybind 绑定的内核来使用 C/C++/CUDA 内核。然而,这种方法无法与 PyTorch 的子系统(如 autograd、torch.compile、vmap 等)协同工作。为了让一个操作能够与 PyTorch 子系统协同工作,必须通过操作符注册 API 进行注册。