PyTorch 2.0 神经网络模块支持

作者: Will Constable

torch.compile 对 torch.nn.Module 对象进行特殊处理,不同于对普通 Python 类的处理方式,它通过对结构做出假设来生成更快速的代码。

本文档描述了由于这种专业化所产生的一些权衡和边缘情况。

NNModule 钩子支持

此前,torch.compile 不支持对 nn.Modules 的钩子(hooks),如果注册了钩子,在编译后的程序中这些钩子将被忽略。确实,许多用户根本不使用 nn.Module 钩子,或者仅在调试工作流程中使用它们,但将 nn.Module 钩子与 torch.compile 结合使用的有效用例是存在的。

通过 nn.Module.__call__ 实现的钩子包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,这些钩子将被称为“调用钩子”。这些钩子部分得到了torch.compile的支持,但存在一些限制。

另一类钩子包括_state_dict_hooks及其preload_变体,但这些目前还不受torch.compile的支持。

nn.Module.__call__ 钩子的使用方法及限制

默认情况下,torch.compile 会追踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行前向和预前向钩子。如果你在调用 torch.compile 之前安装了这些钩子,并且之后没有移除或修改它们,那么默认情况下你的使用场景应该可以得到支持。

Backward/Pre-backward挂钩通常也受支持,但有类似的限制:目前在dynamo中访问backward_hooks字典时会发生图中断,这可能通过一些工作来避免。图中断还会影响反向钩子的触发时间,因为每个图段作为autograd函数运行,并在同一时间生成所有梯度。假设dynamo能够在存在backward-hooks的情况下不发生图中断,我们仍然期望一系列模块的反向钩子在编译后的整个图完成反向传播后一起触发。

‘允许的模块’钩子 torch.compile 将常见的模块(如 torch.conv)以及难以追踪的模块视为特殊情况,允许它们在 dynamo 图中以不透明的方式调用,而不是由 dynamo 进行追踪。对于这些模块,当前的钩子会触发一个图中断,使得受影响的模块在 dynamo 之外运行。这可能会导致性能下降,并且需要额外的工作来改进这种支持。

skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True。这意味着不会在每个 nn.Module 的钩子字典上安装保护措施,从而通过减少保护执行时间来提高运行时性能。然而,这也意味着无法检测到编译后对任何钩子字典的更改。

如果你想在编译后移除或修改钩子,并希望 torch.compile 相应地重新编译,则需要将 skip_nnmodule_hook_guards 设置为 false,并预期因添加的保护措施而产生的运行时开销。

待办事项:确认 backward 和 pre_backward 挂钩是否正常工作,并据此更新文档。

state_dict钩子

当前,torch.compile不支持状态字典挂钩。

待办事项:如果挂钩中断了图表,请使用 warn_once 进行一次警告。如果有挂钩存在,请使用 warn_once 指向此文档。

本页目录