torch.nn.utils.parametrize.cached

torch.nn.utils.parametrize.cached()[源代码]

一个上下文管理器,用于启用在 register_parametrization() 注册的参数化中的缓存系统。

当此上下文管理器激活时,参数化对象的值会在第一次使用时进行计算并缓存。离开上下文管理器时,这些缓存的值会被清除。

这在正向传递过程中多次使用参数化的参数时非常有用,例如在为RNN的循环核进行参数化或共享权重时。

激活缓存的最简单方法是通过包裹神经网络的前向传递。

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在训练和评估过程中,可以封装那些多次使用参数化张量的模块部分。例如,具有参数化递归核的RNN循环:

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)
本页目录