通过 PrivateUse1 简化新后端的集成
在本教程中,我们将逐步介绍如何通过 PrivateUse1
集成一个位于 pytorch/pytorch
仓库外部的新后端。请注意,本教程假设您已经具备 PyTorch 的基础知识,并且是 PyTorch 的高级用户。
本教程仅涉及与PrivateUse1机制相关的部分,该机制有助于集成新设备,其他部分将不予介绍。同时,本教程涉及的所有模块并非都是必需的,您可以根据实际需求选择对您有帮助的模块。
什么是 PrivateUse1?
在 PyTorch 2.0 之前,PyTorch 提供了三个保留的调度键(及其对应的 Autograd 键)用于原型化树外后端扩展,这三个调度键如下:
-
PrivateUse1/AutogradPrivateUse1
-
PrivateUse2/AutogradPrivateUse2
-
PrivateUse3/AutogradPrivateUse3
在原型验证通过后,您可以为新的后端(如 CUDA、XLA、MPS 等)申请私钥。
然而,随着 PyTorch 的快速发展,越来越多的硬件厂商试图将其后端集成到 PyTorch 中,这可能会导致以下问题:
-
每个新的后端集成都需要大量的文件修改
-
目前 Dispatch Keys 的数量存在硬性限制(
DispatchKeySet
64 位限制)
此外,通过 PrivateUse1 键将新后端集成到 PyTorch 中也存在问题,因为无法同时集成多个后端。幸运的是,这些树外后端很少会同时使用。
鉴于上述原因,社区开始建议通过 PrivateUse1
将新的后端集成到 PyTorch 中。
然而,之前的 PrivateUse1
机制并不能完全满足新后端的集成需求,因为它在某些模块中缺乏相关的支持,例如 Storage、AMP、Distributed 等。
随着 PyTorch 2.1.0 的发布,针对 PrivateUse1
在新后端集成方面进行了一系列优化和增强,现在能够快速高效地支持新设备的集成。
如何通过 PrivateUse1 集成新后端
在本节中,我们将讨论通过 PrivateUse1
将新后端集成到 PyTorch 中的细节,主要包括以下几个部分:
-
为新后端注册内核。
-
为新后端注册生成器。
-
为新后端注册设备防护。
-
为新后端元数据注册序列化和反序列化函数。
-
其他模块。
为新后端注册内核
新的后端可能有一些高性能的算子实现,可以通过在C++中注册一个分派算子中描述的TORCH_LIBRARY_IMPL
API注册到调度器中。这涉及以下几种情况:
- 将所有新后端支持的算子注册到调度器中,同时注册回退机制,以便当新后端不支持某些算子时,这些算子可以回退到 CPU 上执行,以确保功能的可用性。
- 如果需要新后端覆盖
PyTorch Autograd 层
,可以通过AutogradPrivateUse1
将torch::autograd::Function
的内核注册到调度器,调度器和自动求导系统将自动调用这些运算符的前向和反向实现。
- 注册希望支持自动混合精度(AMP)和回退机制的内核到调度器,通过
AutocastPrivateUse1
,自动转换系统将在需要时自动调用这些内核。
需要补充的是,如果您想在一个新的后端中支持AMP,您需要通过torch._register_device_module("backend_name", BackendModule)
注册一个新的BackendModule
,并且BackendModule
需要包含以下API:
-
get_amp_supported_dtype() -> List[torch.dtype]
获取新后端在 AMP 中支持的数据类型,可能会多支持一种
dtype
。 -
is_autocast_enabled() -> bool
检查新后端是否启用了 AMP。
-
get_autocast_dtype() -> torch.dtype
获取新后端在 AMP 中支持的数据类型,该数据类型由
set_autocast_dtype
设置或使用默认的dtype
,默认的dtype
是torch.float16
。 -
set_autocast_enabled(bool) -> None
在新后端上启用或禁用 AMP。
-
set_autocast_dtype(dtype) -> None
设置新后端在 AMP 中支持的数据类型,该数据类型必须包含在
get_amp_supported_dtype
获取的dtypes
中。
为新后端注册生成器
需要支持与新设备对应的生成器。目前,PrivateUse1
可以动态注册自定义生成器,主要分为以下几个步骤。
-
继承
GeneratorImpl
类以实现与新后端对应的生成器类,并实现各种通用方法。 -
定义一个带有单个参数
device index
的新后端builder
。 -
调用
REGISTER_GENERATOR_PRIVATEUSE1
宏完成动态注册。
为新后端注册设备保护
PyTorch 通过 DeviceGuard
提供了与设备、流和事件切换相关的功能。此函数同样适用于 PrivateUse1
键。
-
继承
DeviceGuardImplInterface
类来实现与新后端对应的各种通用方法。 -
调用
C10_REGISTER_GUARD_IMPL
宏来完成动态注册。
为新后端元数据注册序列化和反序列化函数
PyTorch 目前能够动态注册序列化/反序列化函数,以支持在 TensorImpl.ExtraMeta
类中名为 backend_meta_
的新后端附加元数据的序列化和反序列化。您可以参考以下步骤:
-
继承
BackendMeta
类以实现与新后端对应的CustomBackendMetadata
,并在该类中自定义新后端的各种字段。 -
实现新后端的序列化和反序列化函数,函数签名为
void(const at::Tensor&, std::unordered_map<std::string, bool>&)
。 -
调用
TensorBackendMetaRegistry
宏完成动态注册。
其他模块
除了上述部分外,还有一些其他模块可以通过 PrivateUse1
进行扩展,例如 分布式集体通信
、基准测试计时器
等,这些模块将在未来逐步添加。关于 PrivateUse1
集成的一个例子是 Ascend NPU。
如何通过 Privateuse1 提升用户体验
通过 PrivateUse1
集成新设备的主要目标是满足基本功能需求,接下来要做的是提升可用性,这主要涉及以下几个方面。
-
将新的后端模块注册到 Pytorch。
-
将
PrivateUse1
重命名为新后端的自定义名称。 -
生成与新后端相关的方法和属性。
向 PyTorch 注册新的后端模块
PyTorch 中的一些 CUDA 相关接口可以通过以下形式调用:torch.cuda.xxx
。因此,为了适应用户习惯,通过 PrivateUse1
机制实现的新后端也应提供类似的接口。
例如,使用 Ascend NPU
:
在执行上述操作后,用户可以通过 torch.npu.xxx
调用 Ascend NPU
的一些专属 API。
将 PrivateUse1 重命名为新后端的自定义名称
PrivateUse1
键是集成到 PyTorch 中的新后端的内部机制。对于用户而言,与 PrivateUse1
相比,与新后端紧密相关的自定义名称会更加友好。
以 Ascend NPU
为例,第一种用法对用户来说会更加友好。
现在,PyTorch 为名为 PrivateUse1
的后端提供了一个新的 C++/Python API,使用起来非常简单。
生成与新后端相关的方法和属性
将 PrivateUse1
重命名为自定义名称后,自动在 Tensor
、nn
、Storage
模块中生成与新后端名称相关的属性和方法。
以下是以 Ascend NPU
为例的示例:
然后,您可以使用以下方法和属性:
未来工作
PrivateUse1
机制的改进仍在进行中,因此新模块的 PrivateUse1
集成方法将逐步添加。以下是我们正在积极处理的几项内容:
-
添加
分布式集体通信
的集成方法。 -
添加
基准计时器
的集成方法。
结论
本教程指导您通过 PrivateUse1
将新后端集成到 PyTorch 的过程,包括但不限于操作符注册、生成器注册、设备保护注册等。同时,还介绍了一些提升用户体验的方法。