torch.ao.ns._numeric_suite

警告

此模块为早期原型,可能随时更改。

torch.ao.ns._numeric_suite.compare_weights(float_dict, quantized_dict)[源代码]

比较浮点模块与其对应的量化模块的权重,并返回一个字典。该字典的键为模块名称,每个条目的值是一个包含“float”和“quantized”两个键的字典,分别存储浮点和量化的权重。此字典可用于比较并计算浮点模型与量化模型之间权重的量化误差。

示例用法:

wt_compare_dict = compare_weights(
    float_model.state_dict(), qmodel.state_dict())
for key in wt_compare_dict:
    print(
        key,
        compute_error(
            wt_compare_dict[key]['float'],
            wt_compare_dict[key]['quantized'].dequantize()
        )
    )
参数
  • float_dict (Dict[str, Any]) – 浮点模型的状态字典

  • quantized_dict (Dict[str, Any]) – 量化模型的状态字典

返回值

一个字典,其中键是模块名称,每个值是一个包含两个键 'float' 和 'quantized' 的字典,分别存储浮点和量化的权重。

返回类型

权重字典 (weight_dict)

torch.ao.ns._numeric_suite.get_logger_dict(mod, prefix='')[源代码]

遍历各个模块,并将所有的日志统计信息保存到目标字典中,主要用于量化精度的调试。

支持的日志类型:

ShadowLogger:用于记录量化模块及其对应的浮点影子模块的输出,OutputLogger:用于记录各个模块的输出

参数
  • mod (Module) – 需要保存所有日志统计信息的模块

  • prefix (str) – 模块的前缀

返回值

保存所有日志器统计信息的字典

返回类型

目标字典

torch.ao.ns._numeric_suite.Logger[源代码]

统计日志的基础类

forward(x)[源代码]
torch.ao.ns._numeric_suite.ShadowLogger[源代码]

用于在Shadow模块中记录原模块和影子模块输出的类。

forward(x, y)[源代码]
torch.ao.ns._numeric_suite.OutputLogger[源代码]

用于记录模块输出的日志类

forward(x)[源代码]
classtorch.ao.ns._numeric_suite.Shadow(q_module, float_module, logger_cls)[源代码]

Shadow 模块将浮点模块连接到相应的量化模块上作为影子,然后使用 Logger 模块处理这两个模块的输出。

参数
  • q_module — 由 float_module 量化而来的一个模块,我们要为它生成一个替代版本

  • float_module - 用于替代 q_module 的浮点模块

  • logger_cls - 处理 q_module 和 float_module 输出的日志记录器类型。可以使用 ShadowLogger 或自定义日志记录器。

forward(*x)[源代码]
返回类型

Tensor

add(x, y)[源代码]
返回类型

Tensor

add_scalar(x, y)[源代码]
返回类型

Tensor

mul(x, y)[源代码]
返回类型

Tensor

mul_scalar(x, y)[源代码]
返回类型

Tensor

cat(x, dim=0)[源代码]
返回类型

Tensor

add_relu(x, y)[源代码]
返回类型

Tensor

torch.ao.ns._numeric_suite.prepare_model_with_stubs(float_module, q_module, module_swap_list, logger_cls)[源代码]

通过将浮点模块附加到其对应的量化模块作为影子来准备模型,前提是该浮点模块类型存在于module_swap_list中。

示例用法:

prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
q_model(data)
ob_dict = get_logger_dict(q_model)
参数
  • float_module (Module) – 用于生成 q_module 的浮点模块

  • q_module (Module) – 由 float_module 量化而来的模块

  • module_swap_list (Set[type]) – 一个浮点模块类型的集合,用于附加影子

  • logger_cls (Callable) – 在影子模块中用于处理量化模块及其浮点影子模块输出的日志类类型

torch.ao.ns._numeric_suite.compare_model_stub(float_model, q_model, module_swap_list, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.ShadowLogger'>)[源代码]

将模型中的量化模块与其对应的浮点模块进行比较,并给两者提供相同的输入。返回一个字典,其中键对应模块名称,每个条目包含两个键“float”和“quantized”,分别存储量化模块及其匹配的浮点影子模块的输出张量。此字典可用于比较并计算模块级别的量化误差。

此函数首先调用 prepare_model_with_stubs() 方法来替换我们想要比较的量化模块与 Shadow 模块。该方法接受量化模块、对应的浮点模块以及日志记录器作为输入,并在内部创建一个前向路径,使浮点模块能够影子跟踪量化模块并共享相同的输入。日志记录器可以自定义,默认使用 ShadowLogger,它会保存量化模块和浮点模块的输出,这些输出可用于计算模块级别的量化误差。

示例用法:

module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
for key in ob_dict:
    print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
参数
  • float_model (Module) – 用于生成 q_model 的浮点模型

  • q_model (Module) – 由 float_model 量化而来的模型

  • module_swap_list (Set[type]) – 在这些浮点模块类型的位置附加影子模块。

  • data — 用于运行预准备的 q_model 的输入数据

  • logger_cls — 在影子模块中使用的日志记录器类型,用于处理量化模块及其浮点影子模块的输出。

返回类型

Dict[str, Dict]

torch.ao.ns._numeric_suite.get_matching_activations(float_module, q_module)[源代码]

找到浮点模块和量化模块之间匹配的激活。

参数
  • float_module (Module) – 用于生成 q_module 的浮点模块

  • q_module (Module) – 由 float_module 量化而来的模块

返回值

这是一个字典,其中键是量化模块的名称,每个条目的值是一个包含两个键“float”和“quantized”的字典,分别存储对应的浮点和量化激活值。

返回类型

act_dict

torch.ao.ns._numeric_suite.prepare_model_outputs(float_module, q_module, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[源代码]

通过将日志记录器同时附加到浮点模块和量化模块(如果它们在允许列表中),来准备模型。

参数
  • float_module (Module) – 用于生成 q_module 的浮点模块

  • q_module (Module) – 由 float_module 量化而来的模块

  • logger_cls - 要附加到 float_module 和 q_module 的日志器的类型

  • allow_list - 指定要为其添加日志记录器的模块类型的列表

torch.ao.ns._numeric_suite.compare_model_outputs(float_model, q_model, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[源代码]

比较相同输入在浮点模型和量化模型中对应位置的输出激活情况。返回一个字典,其中键为量化模块名称,每个条目包含两个键“float”和“quantized”,分别存储匹配位置处量化模型和浮点模型的激活值。此字典可用于比较并计算传播过程中的量化误差。

示例用法:

act_compare_dict = compare_model_outputs(float_model, qmodel, data)
for key in act_compare_dict:
    print(
        key,
        compute_error(
            act_compare_dict[key]['float'],
            act_compare_dict[key]['quantized'].dequantize()
        )
    )
参数
  • float_model (Module) – 用于生成 q_model 的浮点模型

  • q_model (Module) – 由 float_model 量化而来的模型

  • data — 用于运行预设的浮点模型和量化模型的输入数据

  • logger_cls - 要附加到 float_module 和 q_module 的日志器的类型

  • allow_list - 指定要为其添加日志记录器的模块类型的列表

返回值

这是一个字典,其中键是量化模块的名称,每个条目的值是一个包含两个键“float”和“quantized”的字典,分别存储对应的浮点和量化激活值。

返回类型

act_compare_dict

本页目录