torch.__future__

torch.__future__.set_overwrite_module_params_on_conversion(value)[源代码]

设置在将 nn.Module 转换时,使用新张量替换现有参数而不是直接修改现有参数。

启用后,以下方法将为模块分配新的参数:

  1. module.{device}()(例如 nn.Module.cuda())用于将模块在不同的设备之间进行移动

  2. module.{dtype}()(例如 nn.Module.float())用于将模块转换为不同数据类型

  3. nn.Module.to()

  4. nn.Module.to_empty()

参数

value (bool) – 是否分配新张量的标志。

torch.__future__.get_overwrite_module_params_on_conversion()[源代码]

在将 torch.nn.Module 转换时,返回是否为参数分配新的张量而不是就地更改现有参数。默认值为 False

更多详情请参见set_overwrite_module_params_on_conversion()

返回类型

bool

torch.__future__.set_swap_module_params_on_conversion(value)[源代码]

设置是否在将nn.Module转换为其他形式时,使用swap_tensors()来替换原本通过设置.data属性原地更改现有参数的方法;以及在加载状态字典到nn.Module时,使用param.copy_(state_dict[key])

注意

此函数的优先级高于 get_overwrite_module_params_on_conversion()

启用后,以下方法将就地替换现有的参数:

  1. module.{device}()(例如 nn.Module.cuda())用于将模块在不同的设备之间进行移动

  2. module.{dtype}()(例如 nn.Module.float())用于将模块转换为不同数据类型

  3. nn.Module.to()

  4. nn.Module.to_empty()

  5. nn.Module.load_state_dict()

当启用此选项时,load_state_dict() 的行为如下:

  1. 对于每个参数/缓冲区,其对应的 state_dict['key'] 会通过module_load() 方法进行转换(即执行 res = param.module_load(state_dict['key']))。

  2. 如果必要,res 将会被包裹在一个 Parameter 中。

  3. 模块中的参数/缓冲区将通过swap_tensors()res进行交换

参数

value (bool) – 是否使用 swap_tensors()

torch.__future__.get_swap_module_params_on_conversion()[源代码]

返回是否在将 nn.Module 转换时使用swap_tensors() 替代设置 .data 来就地更改现有参数。默认值为 False

更多详情请参见set_swap_module_params_on_conversion()

返回类型

bool

本页目录