torch.nn.utils.prune.identity

torch.nn.utils.prune.identity(module, name)[源代码]

进行剪枝重新参数化,但不修剪任何单元。

module中名为name的参数对应的张量应用剪枝重新参数化,但不实际删除任何单元。通过以下方式就地修改模块,并返回修改后的模块:

  1. 添加一个名为name+'_mask'的缓冲区,该缓冲区对应于剪枝方法应用于参数name的二进制掩码。

  2. 将参数 name 替换为其修剪后的版本,而原始(未修剪)参数则存储在一个名为 name+'_orig' 的新参数中。

注意

掩码是一个由全1组成的张量。

参数
  • module (nn.Module) – 包含待修剪张量的模块。

  • name (str) – 进行剪枝操作的模块中的参数名称。

返回值

输入模块的修改版(即简化后的版本)

返回类型

模块(nn.Module

示例

>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> print(m.bias_mask)
tensor([1., 1., 1.])
本页目录