torch.cuda.jiterator._create_multi_output_jit_fn

torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs, **kwargs)[源代码]

为一个元素级操作创建一个由jiterator生成的支持返回一个或多个输出的CUDA内核。

参数
  • code_string (str) – 需要由jiterator编译的CUDA代码字符串。入口函数必须通过引用来返回值。

  • num_outputs (int) – 内核返回的输出数量

  • kwargs (Dict, 可选) – 用于生成函数的关键词参数

返回类型

Callable

示例:

code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)

警告

此 API 处于 beta 阶段,未来版本可能有所更改。

警告

此 API 最多支持 8 个输入和 8 个输出。

本页目录