torch.Tensor.register_post_accumulate_grad_hook

Tensor.register_post_accumulate_grad_hook(hook)[源代码]

注册一个在梯度累积之后运行的反向钩子。

当张量的所有梯度都累加完成后,会调用该钩子函数,此时张量的.grad属性已被更新。后累积梯度钩子仅适用于叶张量(即没有.grad_fn属性的张量)。在非叶张量上注册此钩子将引发错误!

钩子应具有如下定义:

hook(param: Tensor) -> None

请注意,与其他自动求导钩子不同,这个钩子操作的是需要计算梯度的张量,而不是梯度本身。它可以就地修改并访问其张量参数,包括.grad字段。

此函数返回一个包含 handle.remove() 方法的句柄,该方法可以用于从模块中移除钩子。

注意

有关此钩子何时执行及其与其他钩子的执行顺序,请参见Backward Hooks 执行。由于该钩子在反向传递期间运行,因此默认情况下会处于 no_grad 模式(除非 create_graph 为 True)。如果你需要启用自动微分功能,可以使用 torch.enable_grad() 在钩子内部重新开启 autograd。

示例:

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook
本页目录