torch.nn.utils.prune.is_pruned

torch.nn.utils.prune.is_pruned(module)[源代码]

通过查看剪枝预挂钩来检查模块是否被剪枝。

通过查找模块中继承自BasePruningMethodforward_pre_hooks来检查module是否已被剪枝。

参数

module (nn.Module) – 该对象可以是已经被剪枝的,也可以是没有被剪枝的

返回值

二进制答案,表示模块是否被剪枝。

示例

>>> from torch.nn.utils import prune
>>> m = nn.Linear(5, 7)
>>> print(prune.is_pruned(m))
False
>>> prune.random_unstructured(m, name='weight', amount=0.2)
>>> print(prune.is_pruned(m))
True
本页目录