torch.nn.utils.prune.custom_from_mask

torch.nn.utils.prune.custom_from_mask(module, name, mask)[源代码]

使用mask中的预计算掩码来修剪module中名为name的参数对应的张量。

通过以下方式就地修改模块(并将修改后的模块返回):

  1. 添加一个名为name+'_mask'的缓冲区,该缓冲区对应于剪枝方法应用于参数name的二进制掩码。

  2. 将参数 name 替换为其修剪后的版本,而原始(未修剪)参数则存储在一个名为 name+'_orig' 的新参数中。

参数
  • module (nn.Module) – 包含待修剪张量的模块

  • name (str) – 进行剪枝操作的模块中的参数名称。

  • mask (张量) – 应用于参数的二值掩码。

返回值

输入模块的修改版(即简化后的版本)

返回类型

模块(nn.Module

示例

>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
本页目录