torch.nn.modules.module.register_module_forward_hook
- torch.nn.modules.module.register_module_forward_hook(hook, *, always_call=False)[源代码]
-
为所有的模块注册一个全局前进钩子。
警告
这在nn.module模块中添加了全局状态,仅用于调试和性能分析。
每次
forward()
计算出输出后,都会调用这个钩子。它应具有以下签名:hook(module, input, output) -> None or modified output
输入仅包含传递给模块的位置参数,关键字参数不会传递到钩子,只会传递给
forward
方法。钩子可以修改输出结果,并且可以就地修改输入数据,但这样做在前向传播中不会有影响,因为钩子是在调用forward()
之后才被调用的。- 参数
-
-
hook (Callable) – 用户自定义的要注册的钩子。
-
always_call (bool) – 如果为
True
,则在调用 Module 时无论是否引发异常,hook
都会被执行。默认值:False
-
- 返回值
-
一个可以通过调用
handle.remove()
移除添加的钩子的句柄 - 返回类型
-
torch.utils.hooks.RemovableHandle
此钩子将在使用
register_forward_hook
注册的特定模块钩子之前执行。