torch.nn.utils.parametrize.cached
- torch.nn.utils.parametrize.cached()[源代码]
-
一个上下文管理器,用于启用在
register_parametrization()
注册的参数化中的缓存系统。当此上下文管理器激活时,参数化对象的值会在第一次使用时进行计算并缓存。离开上下文管理器时,这些缓存的值会被清除。
这在正向传递过程中多次使用参数化的参数时非常有用,例如在为RNN的循环核进行参数化或共享权重时。
激活缓存的最简单方法是通过包裹神经网络的前向传递。
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在训练和评估过程中,可以封装那些多次使用参数化张量的模块部分。例如,具有参数化递归核的RNN循环:
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)