torch.jit.freeze

torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)[源代码]

将 ScriptModule 冻结,并将其子模块和属性作为常量内联。

冻结一个ScriptModule 会将其克隆,并尝试将克隆模块的子模块、参数和属性作为常量内联到 TorchScript IR 图中。默认情况下,forward 方法以及在 preserved_attrs 中指定的其他属性和方法会被保留。此外,在被保留的方法内部修改的任何属性也会被保留。

目前,只有处于评估模式的ScriptModules可以被冻结。

冻结应用了通用优化,可以提升任何机器上模型的运行速度。为了进一步利用特定服务器的设置进行优化,请在冻结后运行optimize_for_inference

参数
  • mod (ScriptModule) - 一个要被冻结的模块

  • preserved_attrs (Optional[List[str]]) – 除了前向方法之外,要保留的属性列表。在被保留的方法中修改的属性也会被保留。

  • optimize_numerics (bool) – 如果为 True,将运行一组不严格保留数值的优化流程。有关详细信息,请参见torch.jit.run_frozen_optimizations

返回值

冻结的 ScriptModule

示例(使用参数冻结简单模块):

    def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)

示例:带有保留属性的模块冻结

    def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)

注意

也可以冻结子模块的属性:frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])

注意

如果你不确定为什么某个属性没有被内联为常量,可以运行 dump_alias_db 查看 frozen_module.forward.graph,确认冻结过程中是否检测到该属性已被修改。

注意

由于冻结权重使其成为常量并移除模块层次结构,to 和其他 nn.Module 方法用于更改设备或数据类型不再起作用。作为替代方案,您可以在 torch.jit.load 中通过指定 map_location 来重新映射设备,但模型中可能已经嵌入了特定于设备的逻辑。

本页目录