torch.library

torch.library 是一组用于扩展 PyTorch 核心操作符库的 API。它提供了测试、创建和扩展自定义操作符的功能,包括使用 PyTorch 的 C++ 操作符注册 API(如 aten 操作符)定义的操作符。

要详细了解如何有效使用这些 API,请参阅PyTorch 自定义操作符首页

测试自定义操作

使用torch.library.opcheck()来检查自定义操作是否正确使用了Python的torch.library和/或C++的TORCH_LIBRARY API。此外,如果你的操作支持训练,请使用torch.autograd.gradcheck()来验证梯度计算是否数学上正确。

torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True)[源代码]

给定一个操作符和一些示例参数,测试该操作符是否已正确注册。

也就是说,当你使用 torch.library/TORCH_LIBRARY API 创建自定义操作时,你需要指定该操作的元数据(例如可变性信息),并确保传递给这些 API 的函数满足某些要求(例如在假/元/抽象内核中不访问数据指针)。opcheck 会验证这些元数据和属性。

具体来说,我们要测试以下内容:

  • test_schema: 如果模式与操作符的实现相匹配。例如,如果模式指定张量被修改了,那么我们会检查实现中确实进行了这样的修改。如果模式指定了返回一个新的张量,则我们检查实现是否真正返回了一个新的张量(而不是现有的一个或现有张量的一个视图)。

  • test_autograd_registration: 如果操作符支持训练(自动微分),我们会检查其自动微分公式是否通过 torch.library.register_autograd 或手动注册到一个或多个 DispatchKey::Autograd 关键。任何其他基于 DispatchKey 的注册可能会导致未定义的行为。

  • test_faketensor: 如果操作符有一个 FakeTensor 内核,并且该内核是正确的。FakeTensor 内核是操作符与 PyTorch 编译 API(如 torch.compile、export 和 FX)配合工作的必要条件,但不是充分条件。我们检查是否为操作符注册了 FakeTensor 内核(也称为元数据内核),并确保其正确性。此测试会比较在真实张量和 FakeTensor 上运行该操作符的结果,以验证它们具有相同的 Tensor 元数据(如大小、步幅、数据类型和设备等)。

  • test_aot_dispatch_dynamic: 检查操作符在 PyTorch 编译 API(torch.compile/export/FX)中的行为是否正确。此测试确保在 eager 模式下的 PyTorch 和 torch.compile 中的输出(以及适用的梯度)一致。此测试是 test_faketensor 的超集,并且是一个端到端测试;它还验证操作符是否支持函数化,以及反向传递(如果存在)是否也支持 FakeTensor 和函数化。

为了获得最佳结果,请多次调用opcheck,并使用一组具有代表性的输入。如果您的操作符支持自动微分,请将输入的requires_grad = True 传递给opcheck; 如果您的操作符支持多个设备(例如 CPU 和 CUDA),请在所有支持的设备上运行opcheck

参数
  • op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 操作符。必须是使用 torch.library.custom_op() 装饰的函数,或者是在 torch.ops.* 中定义的 OpOverload 或 OpOverloadPacket(例如 torch.ops.aten.sin, torch.ops.mylib.foo)

  • args (Tuple[Any, ...]) – 操作符的参数列表

  • kwargs (Optional[Dict[str, Any]]) – 操作符的参数

  • test_utils (Union[str, Sequence[str]]) – 需要运行的测试用例。默认为所有用例。示例:("test_schema", "test_faketensor")

  • raise_exception (bool) – 是否在第一次出现错误时抛出异常。如果为 False,则返回一个包含每个测试结果的字典。

返回类型

Dict[str, str]

警告

opcheck 和 torch.autograd.gradcheck() 测试的内容不同:opcheck 用于检查你是否正确使用了 torch.library API,而 torch.autograd.gradcheck() 则用于验证你的自动微分公式在数学上是否准确。为了确保自定义操作符支持梯度计算,请同时使用这两种方法进行测试。

示例

>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_add(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np + y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_sin.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     y, = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_sin.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(foo, args)

用Python创建新的自定义操作

使用torch.library.custom_op()来创建新的自定义操作。

torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None)

将函数封装成自定义操作符。

你可能想要创建自定义操作的原因包括:- 将第三方库或自定义内核包装起来,以便与PyTorch子系统(如Autograd)一起使用。- 防止torch.compile/export/FX 跟踪窥探你的函数内部。

此 API 用作函数的装饰器(请参见示例)。提供的函数必须包含类型提示,因为这些提示是与 PyTorch 各个子系统进行交互所必需的。

参数
  • name (str) – 自定义操作的名称,格式为“{命名空间}::{名称}”,例如 “mylib::my_linear”。该名称在 PyTorch 子系统(如 torch.export 和 FX 图形)中用作操作的唯一标识符。为了避免名称冲突,请使用项目名称作为命名空间;例如,在 pytorch/fbgemm 中的所有自定义操作都使用“fbgemm”作为命名空间。

  • mutates_args (Iterable[str] 或 "unknown") – 函数修改的参数名称。必须确保这一点准确,否则行为将无法预测。如果设置为“unknown”,则假设所有输入都在被该操作符修改。

  • device_types (None | str | Sequence[str]) – 函数适用的设备类型。如果没有指定设备类型,则该函数将作为所有设备类型的默认实现使用。示例:“cpu”,“cuda”。当为不接受张量的操作符注册特定于设备的实现时,我们要求操作符具有一个“device: torch.device”参数。

  • schema (None | str) – 操作符的模式字符串。如果为 None(推荐),系统将根据操作符的类型注解自动推断其模式。除非有特殊原因,否则建议使用自动推断模式的方式。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回类型

Callable

注意

我们建议不要传入schema参数,而是让系统根据类型注解来推断它。自行编写模式容易出错。如果你认为我们的解释不符合你的需求,可以提供自己的模式。有关如何编写模式字符串的更多信息,请参见此处

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>>     return torch.ones(3)
>>>
>>> bar("cpu")

扩展自定义操作(使用Python或C++创建)

使用 register.* 方法,例如 torch.library.register_kernel() 和函数 torch.library.register_fake,为任何操作符添加实现(这些操作符可能是通过 torch.library.custom_op() 或 PyTorch 的 C++ 操作符注册 API 创建的)。

torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[源代码]

为该操作符的设备类型注册其实现。

有效的设备类型有:“cpu”,“cuda”,“xla”,“mps”,“ipu”,“xpu”等。此 API 可用作装饰器。

参数
  • fn (Callable) – 用于注册为给定设备类型实现的函数。

  • device_types (None | str | Sequence[str]) – 要注册实现的设备类型。如果为 None,则会注册到所有设备类型 - 请仅在你的实现确实与设备类型无关时使用此选项。

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[源代码]

为这个自定义操作注册一个反向公式。

为了让操作符与自动微分(autograd)协同工作,你需要注册一个反向传播公式:1. 提供“backward”函数以告知我们如何在反向传递期间计算梯度。2. 如果需要从正向传递中获取任何值来计算梯度,你可以使用setup_context保存这些用于反向传递的值。

backward 在反向传递期间运行。它接受(ctx, *grads): - grads 是一个或多个梯度,其数量与操作符的输出数量相匹配。ctx 对象是用于torch.autograd.Function 的同一个 ctx 对象。backward_fn 的语义与torch.autograd.Function.backward() 相同。

setup_context(ctx, inputs, output) 在前向传递期间运行。请通过调用torch.autograd.function.FunctionCtx.save_for_backward() 或将其作为 ctx 对象的属性来保存反向传递所需的量。如果你的自定义操作有仅限关键字参数,我们期望 setup_context 的签名是setup_context(ctx, inputs, keyword_only_inputs, output)

_BOTH_ setup_context_fnbackward_fn 必须是可以追踪的。也就是说,它们不能直接访问torch.Tensor.data_ptr(),并且不能依赖于或修改全局状态。如果你需要一个不可追踪的反向传播函数,可以将其作为单独的操作,在backward_fn中调用。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> Tensor:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1)[源代码]

为该操作符注册一个伪张量实现(“假实现”)。

有时也称为“元内核”或“抽象实现”。

“FakeTensor 实现”定义了此操作符在不携带数据的张量(“FakeTensor”)上的行为。给定一些具有特定属性(如大小、步幅、存储偏移和设备)的输入张量,它指定了输出张量的相应属性。

FakeTensor 的实现与操作符具有相同的签名,并且既适用于 FakeTensors 也适用于元张量。编写一个 FakeTensor 实现时,假设所有传递给操作符的 Tensor 输入都是常规的 CPU/CUDA/Meta 张量,但它们没有实际存储数据。你需要返回常规的 CPU/CUDA/Meta 张量作为输出。此外,FakeTensor 的实现必须仅由 PyTorch 操作组成,并且不能直接访问任何输入或中间张量的实际存储或数据。

此 API 可以用作装饰器(参见示例)。

关于自定义操作的详细指南,请参见 https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an fake impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
torch.library.register_vmap(op, func=None, /, *, lib=None)[源代码]

为这个自定义操作注册一个 vmap 实现,以便支持 torch.vmap()

此 API 可以用作装饰器(参见示例)。

为了使一个操作符与torch.vmap()协同工作,你需要按照以下签名注册一个vmap实现:

vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),

其中 *args**kwargs 分别是 op 的位置参数和关键字参数。我们不支持仅限关键字的张量参数。

它指定了如何根据增加了一个维度的输入(由in_dims指定)来计算op的批量版本。

对于每个 args 中的参数,in_dims 对应一个 Optional[int]。如果该参数不是张量或不被 vmapped 处理,则其值为 None;否则,它是一个整数,指定要对张量的哪个维度进行 vmapping。

info 包含一些可能有用的附加元数据:info.batch_size 指定了被 vmapped 的维度的大小,而 info.randomness 则是传递给 torch.vmap()randomness 选项。

函数func的返回值是一个元组(output, out_dims)。类似于in_dimsout_dims应该与output具有相同的结构,并且每个输出包含一个out_dim,以指定该输出是否包含vmapped维度及其索引。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)

注意

vmap 函数应该保持整个自定义操作符的语义不变。也就是说,grad(vmap(op)) 可以用 grad(map(op)) 来替换。

如果你的自定义操作符在反向传递中具有任何自定义行为,请记住这一点。

torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[源代码]

此 API 在 PyTorch 2.4 中更名为 torch.library.register_fake()。请使用新的名称。

torch.library.get_ctx()[源代码]

get_ctx() 返回当前的 AbstractImplCtx 对象。

调用 get_ctx() 只能在假实现中有效(请参阅 torch.library.register_fake() 以获取更多使用详情)。

返回类型

FakeImplCtx

torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[源代码]

为给定的操作符和 torch_dispatch_class 注册一条 torch_dispatch 规则。

这允许通过开放注册来指定操作符与torch_dispatch_class之间的行为,而无需直接修改torch_dispatch_class或操作符本身。

torch_dispatch_class 可以是一个具有 __torch_dispatch__ 方法的 Tensor 子类,或者是一个 TorchDispatchMode。

如果它是 Tensor 的子类,我们期望 func 具有以下签名:(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

如果它是 TorchDispatchMode,我们期望 func 具有以下签名:(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

argskwargs 将会被规范化,与在 __torch_dispatch__ 中相同(参见__torch_dispatch__ 调用约定)。

示例

>>> import torch
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -> torch.Tensor:
>>>     return x.clone()
>>>
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(*args, **kwargs)
>>>
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>>
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>>
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)
torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)

解析具有类型提示的给定函数的模式。模式是从函数的类型提示中推断出来的,并可以用于定义新的操作符。

我们作出如下假设:

  • 没有任何输出与其他输入或输出重复。

  • 未指定库的字符串类型注解“device、dtype、Tensor、types”
    假设为 torch.* 类型。同样,字符串类型的注解有 “Optional, List, Sequence, Union” 等。
    未指定库的类型假设为 typing.*。
  • 只有在 mutates_args 中列出的参数会被修改。如果 mutates_args 的值为“未知”,
    它假设该操作符的所有输入都正在被修改。

调用者(如自定义操作API)负责验证这些假设。

参数
  • prototype_function (Callable) – 用于根据其类型注释推断模式的函数。

  • op_name (可选[str]) – 操作符在模式中的名称。如果 name 为 None,则该名称不会包含在推断出的模式中。需要注意的是,torch.library.Library.define 的输入模式需要指定一个操作符名称。

  • mutates_args ("unknown"|Iterable[str]) – 表示函数中被修改的参数。

返回值

推导出来的模式。

返回类型

str

示例

>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>>     return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
classtorch._library.custom_ops.CustomOpDef(namespace, name, schema, fn)[源代码]

CustomOpDef 是一个围绕函数的封装器,用于将其转换为自定义操作。

它提供了多种方法来为这个自定义操作添加额外的功能。

不要直接实例化 CustomOpDef;相反,应该使用 torch.library.custom_op() API。

set_kernel_enabled(device_type, enabled=True)[源代码]

禁用或重新启用已为此自定义操作符注册的内核。

如果内核已处于禁用或启用状态,则不会执行任何操作。

注意

如果一个内核先被禁用再被注册,它会一直保持禁用状态,直到重新启用为止。

参数
  • device_type (str) – 设备类型的名称,用于禁用或启用内核。

  • disable (bool) – 控制是否禁用或启用内核。

示例

>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>>     return torch.zeros(1)
>>>
>>> print(f(inp))  # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>>
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled

低级别 API

以下 API 直接绑定到 PyTorch 的 C++ 低级别操作注册 API。

警告

低级别的操作注册API和PyTorch调度器是复杂的PyTorch概念。我们建议您尽可能使用上述高级别API(无需使用torch.library.Library对象)。这篇博客文章 <http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/> 是了解PyTorch调度器的一个很好的起点。

你可以在 Google Colab 上找到一个教程,该教程通过示例展示了如何使用此 API。

classtorch.library.Library(ns, kind, dispatch_key='')[源代码]

这是一个用于从Python创建和使用库的类,这些库可以用来注册新的操作符或替换现有库中的操作符。用户还可以选择传递一个分发键名,以便只为特定的操作符注册内核。

要覆盖现有库(名称为 ns)中的操作符,创建一个新库并将类型设为“IMPL”。若要注册新的操作符,则创建一个新的库(名称为 ns)并将类型设为“DEF”。为了绕过每个命名空间只能有一个库的限制,并在已存在的库中注册操作符,请创建一个片段并将类型设为“FRAGMENT”。

参数
  • ns - 库的名字

  • kind – “DEF”,“IMPL”(默认为 “IMPL”),“FRAGMENT”

  • dispatch_key - PyTorch 的分发键(默认值:"")

define(schema, alias_analysis='', *, tags=())[源代码]

在ns命名空间中定义一个新的运算符及其含义。

参数
  • schema - 用于定义新操作符的功能模式。

  • alias_analysis (可选) - 表示操作符参数的别名属性是否可以从模式中推断(默认行为),否则为“CONSERVATIVE”。

  • tags (Tag|Sequence[Tag]) – 一个或多个 torch.Tag 应用于此操作符。标记操作符会改变其在各种 PyTorch 子系统中的行为;请在应用之前仔细阅读 torch.Tag 的文档。

返回值

根据模式推断出的操作符名称。

示例:
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
fallback(fn, dispatch_key='', *, with_keyset=False)[源代码]

将函数实现注册为给定键的备用选项。

此函数仅适用于拥有全局命名空间(“_”)的库。

参数
  • fn - 作为给定调度键的备用函数,或者用于注册备用选项的 fallthrough_kernel()

  • dispatch_key — 分发键,输入函数应据此进行注册。默认情况下,它使用创建库时所用的分发键。

  • with_keyset - 控制当前调度器调用的键集是否应作为第一个参数传递给 fn 函数。这有助于为重调度调用创建合适的键集。

示例:
>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, *args, **kwargs):
>>>     # Handle all autocast ops generically
>>>     # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")
impl(op_name, fn, dispatch_key='', *, with_keyset=False)[源代码]

为库中定义的操作注册函数实现。

参数
  • op_name - 操作符的名称(包括重载)或 OpOverload 对象。

  • fn — 输入分发键的操作实现函数,或使用fallthrough_kernel()来注册一个备用操作。

  • dispatch_key — 分发键,输入函数应据此进行注册。默认情况下,它使用创建库时所用的分发键。

  • with_keyset - 控制当前调度器调用的键集是否应作为第一个参数传递给 fn 函数。这有助于为重调度调用创建合适的键集。

示例:
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
torch.library.fallthrough_kernel()[源代码]

一个传递给Library.impl的伪函数,用于注册贯穿。

torch.library.define(qualname, schema, *, lib=None, tags=())[源代码]
torch.library.define(lib, schema, alias_analysis='')

定义一个新的运算符。

在 PyTorch 中,定义一个操作符(op)是一个两步过程:1. 定义操作符的名称和模式;2. 实现该操作符与各种 PyTorch 子系统(如 CPU/CUDA 张量、Autograd 等)交互的行为。

这个入口定义了自定义操作符(第一步),然后你需要通过调用各种impl_* API 来完成第二步,例如torch.library.impl() 或者 torch.library.register_fake()

参数
  • qualname (str) – 操作符的限定名。应类似于“命名空间::名称”的字符串,例如“aten::sin”。在 PyTorch 中,每个操作符需要一个唯一的命名空间以避免名称冲突;给定的操作符只能创建一次。如果你正在编写 Python 库,我们建议使用顶级模块的名称作为命名空间。

  • schema (str) – 操作符的模式。例如,对于接受一个 Tensor 并返回一个 Tensor 的操作符,“(Tensor x) -> Tensor”。它不包含操作符名称(在qualname中传递)。

  • lib (Optional[Library]) – 如果提供,此操作符的生命周期将与 Library 对象的生命周期绑定。

  • tags (Tag|Sequence[Tag]) – 一个或多个 torch.Tag 应用于此操作符。标记操作符会改变其在各种 PyTorch 子系统中的行为;请在应用之前仔细阅读 torch.Tag 的文档。

示例:
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
torch.library.impl(qualname, types, func=None, *, lib=None)[源代码]
torch.library.impl(lib, name, dispatch_key='')

为该操作符的设备类型注册其实现。

你可以将“default”传递给 types 以将此实现设置为所有设备类型的默认实现。请仅在该实现确实支持所有设备类型时使用此选项;例如,如果它是内置 PyTorch 操作符的组合,则可以使用。

有效的一些类型包括:“cpu”,“cuda”,“xla”,“mps”,“ipu”,“xpu”。

参数
  • qualname (str) – 应为类似“命名空间::操作符名称”的字符串。

  • types (str|Sequence[str]) – 需要注册实现的设备类型。

  • lib (Optional[Library]) – 如果提供,此注册的生命周期将与 Library 对象的生命周期绑定。

示例

>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())