torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global') [源代码]

这是一个上下文管理器,它将CUDA工作捕获到一个torch.cuda.CUDAGraph对象中,以便后续重播。

参见CUDA 图以获取一般介绍、详细用法和约束条件。

参数
  • cuda_graph (torch.cuda.CUDAGraph) – 用于捕捉操作的图对象。

  • pool (可选) – 由 graph_pool_handle()other_Graph_instance.pool() 返回的不透明标记,提示此图可能从指定池中共享内存。参见Graph 内存管理

  • stream (torch.cuda.Stream, 可选) – 如果提供,则在上下文中设置为当前流。如果没有提供,默认情况下,graph 在上下文中将其内部的辅助流设置为当前流。

  • capture_error_mode (str, 可选) – 指定用于图捕获流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 CUDA 图形捕获期间,某些操作(如 cudaMalloc)可能是不安全的。“global”会在其他线程的操作时出错,“thread_local”仅在当前线程中的操作时出错,并且“relaxed”不会因任何操作而出错。除非你熟悉 cudaStreamCaptureMode,否则请勿更改此设置。

注意

为了有效共享内存,如果你传递一个由之前捕获操作使用的pool,而该之前的捕获操作使用了显式的stream参数,则你应该向此次捕获传递相同的stream参数。

警告

此 API 处于 beta 阶段,未来版本可能有所更改。
本页目录