torch.utils.rename_privateuse1_backend

torch.utils.rename_privateuse1_backend(backend_name)[源代码]

将privateuse1后端设备重命名,使其在PyTorch API中更便于用作设备名称。

步骤是:

  1. (在 C++ 中)实现各种 torch 操作的内核,并将其注册到 PrivateUse1 分派键。

  2. 在 Python 中调用 torch.utils.rename_privateuse1_backend("foo")

你现在可以将“foo”作为普通设备字符串使用了。

注意:此API每个进程只能调用一次。如果在外部后端已设置后再进行更改,将会引发错误。

Note(AMP): 如果你想在设备上支持 AMP,可以注册一个自定义后端模块。你需要使用torch._register_device_module("foo", BackendModule) 来注册该模块,并且 BackendModule 需要提供以下 API。

  1. get_amp_supported_dtype() -> List[torch.dtype] 获取你在“foo”设备上自动混合精度(AMP)模式下支持的数据类型,可能该设备还支持其他数据类型。

注意:如果你想为你设备的种子设置功能提供支持,BackendModule 需要包含以下 API:

  1. _is_in_bad_fork() 如果当前处于 bad_fork 状态,则返回 True,否则返回 False

  2. manual_seed_all(seed int) -> None 设置生成随机数的种子,以便应用于您的设备。

  3. device_count() 返回可用的“foo”数量。

  4. get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor 返回一个包含所有设备随机数状态的ByteTensor列表。

  5. set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None 设置指定设备(如 "foo")的随机数生成器状态。

还有一些常用的函数:
  1. is_available() -> bool 返回一个布尔值,指示“foo”当前是否可用。

  2. current_device() 返回当前选定设备的索引(类型为 int)。

更多详细信息,请参阅https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend。一个现有示例,请参见 https://github.com/bdhirsh/pytorch_open_registration_example

示例:

>>> torch.utils.rename_privateuse1_backend("foo")
# This will work, assuming that you've implemented the right C++ kernels
# to implement torch.ones.
>>> a = torch.ones(2, device="foo")
本页目录