fuse_modules
- class torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[源代码]
-
将一系列模块合并成一个单一模块。
只融合以下模块序列:conv, bn conv, bn, relu conv, relu linear, relu bn, relu。其他所有序列保持不变。对于这些序列,用融合后的模块替换列表中的第一个模块,并将其余模块替换为身份模块。
- 参数
-
-
model – 包含需要融合的模块的模型
-
modules_to_fuse - 要融合的模块名称列表的列表。如果有且仅有一个要融合的模块列表,则该参数也可以是字符串列表。
-
inplace - 指定融合是否在模型的原地进行,否则将返回一个新的模型。
-
fuser_func — 一个接受模块列表并返回相同长度的融合模块列表的函数。例如,fuser_func([convModule, BNModule]) 返回 [ConvBNModule, nn.Identity()]。默认值为 torch.ao.quantization.fuse_known_modules。
-
fuse_custom_config_dict - 自定义融合设置
-
# Example of fuse_custom_config_dict fuse_custom_config_dict = { # Additional fuser_method mapping "additional_fuser_method_mapping": { (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn }, }
- 返回值
-
具有融合模块的模型。如果 inplace=True,将创建一个新副本。
示例:
>>> m = M().eval() >>> # m is a module containing the sub-modules below >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input) >>> m = M().eval() >>> # Alternately provide a single list of modules to fuse >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input)