torch.jit.optimize_for_inference

torch.jit.optimize_for_inference(mod, other_methods=None)[源代码]

进行一系列优化操作,以便将模型优化用于推理目的。

如果模型未被冻结,optimize_for_inference 将会自动调用 torch.jit.freeze

除了能够在任何环境中加快模型速度的通用优化之外,准备推理还会包含特定于构建的设置,例如是否存在 CUDNN 或 MKLDNN。未来可能会进行一些转换,这些转换在一台机器上可以加速但在另一台机器上可能减速。因此,在调用optimize_for_inference之后不会实现序列化,并且不能保证其效果。

这仍然处于原型阶段,可能会影响你的模型运行速度。目前主要的应用场景是在CPU和GPU上运行视觉模型,但对GPU的影响相对较小。

示例(通过Conv->Batchnorm优化模块):

import torch
in_channels, out_channels = 3, 32
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
assert "batch_norm" not in str(frozen_mod.graph)
# if built with MKLDNN, convolution will be run with MKLDNN weights
assert "MKLDNN" in frozen_mod.graph
返回类型

ScriptModule

本页目录