torch.nn.utils.fuse_linear_bn_weights

torch.nn.utils.fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)[源代码]

将线性模块参数与批处理规范化模块参数合并为新的线性模块参数。

参数
返回值

将线性权重和偏置融合在一起。

返回类型

(torch.nn.Parameter, torch.nn.Parameter)

本页目录