torch.nn.utils.parametrize.remove_parametrizations

torch.nn.utils.parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)[源代码]

去掉模块中张量的参数化。

  • 如果 leave_parametrized=True,则将 module[tensor_name] 设置为它的当前输出值。此时,参数化不应更改张量的 dtype

  • 如果 leave_parametrized=False,则将 module[tensor_name] 设置为 module.parametrizations[tensor_name].original 中的非参数化张量。这仅在参数化只依赖于一个张量时才可行。

参数
  • module (nn.Module) – 需要移除参数化的模块

  • tensor_name (str) – 需要移除的参数化名称

  • leave_parametrized (bool, 可选) – 是否保留属性 tensor_name 的参数化形式。默认值:True

返回值

模块

返回类型

模块

异常
  • ValueError – 如果 module[tensor_name] 未进行参数化

  • ValueError – 当 leave_parametrized=False 且参数化依赖于多个张量时引发

本页目录