torch.cuda.jiterator._create_jit_fn
- torch.cuda.jiterator._create_jit_fn(code_string, **kwargs)[源代码]
-
为一个元素级操作创建一个由jiterator生成的CUDA内核。
代码字符串必须是一个有效的CUDA函数,用于描述单个元素的计算过程。该字符串需要遵循C++模板模式,如下面的例子所示。此函数将被内联到逐元素内核模板中,并在运行时进行编译。编译后的内核将在内存和本地临时目录中缓存。
Jiterator生成的内核可以处理非连续张量,并支持广播和类型提升功能。
- 参数
-
-
code_string (str) – 需要由 jiterator 编译的 CUDA 代码字符串。入口函数必须通过值返回。
-
kwargs (Dict, 可选) – 用于生成函数的关键词参数
-
- 返回类型
示例:
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) a = torch.rand(3, device='cuda') b = torch.rand(3, device='cuda') # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14)
code_string 允许定义多个函数,其中最后一个函数会被视为入口函数。
示例:
code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" jitted_fn = create_jit_fn(code_string, val=0.0) a = torch.rand(3, device='cuda') b = torch.rand(3, device='cuda') # invoke jitted function like a regular python function result = jitted_fn(a, b) # using default val=0.0
Jiterator 可以与 Python 注册结合使用,来覆盖操作符的 CUDA 内核。下面的例子展示如何用 relu 覆盖 gelu 的 CUDA 内核。
示例:
code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }" my_gelu = create_jit_fn(code_string) my_lib = torch.library.Library("aten", "IMPL") my_lib.impl('aten::gelu', my_gelu, "CUDA") # torch.nn.GELU and torch.nn.function.gelu are now overridden a = torch.rand(3, device='cuda') torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
警告
此 API 处于 beta 阶段,未来版本可能有所更改。警告
此 API 最多支持 8 个输入和 1 个输出。
警告
所有的输入张量都必须在CUDA设备上