PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

剪枝教程

作者: Michela Paganini

最先进的深度学习技术依赖于难以部署的过度参数化模型。相反,生物神经网络以使用高效的稀疏连接而闻名。为了在不牺牲准确性的情况下减少内存、电池和硬件消耗,识别通过减少参数数量来压缩模型的最佳技术至关重要。这反过来使您能够在设备上部署轻量级模型,并通过设备上的私有计算来确保隐私。在研究前沿,剪枝被用于研究过度参数化和欠参数化网络之间学习动态的差异,研究幸运稀疏子网络和初始化的作用(“彩票假设”)作为一种破坏性的神经架构搜索技术,等等。

在本教程中,您将学习如何使用 torch.nn.utils.prune 来稀疏化您的神经网络,以及如何扩展它以实现您自己的自定义剪枝技术。

环境要求

"torch>=1.4.0a0+8e8a5e0"

importtorch
fromtorchimport nn
importtorch.nn.utils.pruneasprune
importtorch.nn.functionalasF

创建模型

在本教程中,我们使用 LeCun 等人在 1998 年提出的 LeNet 架构。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

classLeNet(nn.Module):
    def__init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    defforward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查模块

让我们检查一下 LeNet 模型中(未经修剪的)conv1 层。目前,它将包含两个参数 weightbias,并且没有缓冲区。

module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.1061, -0.0808,  0.1214],
          [-0.0475,  0.1144, -0.1554, -0.1009,  0.0610],
          [ 0.0423, -0.0510,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0675, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0500, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
          [ 0.1999,  0.0378,  0.0616, -0.1865, -0.1314],
          [-0.0666,  0.0313, -0.1760, -0.0862, -0.1197],
          [ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],


        [[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],
          [ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',
       requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[]

模块剪枝

要剪枝一个模块(在本例中,LeNet 架构的 conv1 层),首先从 torch.nn.utils.prune 中提供的剪枝技术中选择一种(或通过子类化 BasePruningMethod实现您自己的剪枝方法)。然后,指定模块以及该模块中要剪枝的参数的名称。最后,使用所选剪枝技术所需的关键字参数,指定剪枝参数。

在本例中,我们将随机剪枝 conv1 层中名为 weight 的参数的 30% 连接。模块作为函数的第一个参数传递;name 使用其字符串标识符标识该模块中的参数;amount 表示要剪枝的连接的比例(如果它是 0 到 1 之间的浮点数),或者要剪枝的连接的绝对数量(如果它是非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

剪枝操作通过从参数中移除 weight,并将其替换为名为 weight_orig 的新参数(即在初始参数 name 后附加 "_orig")。weight_orig 存储了张量的未剪枝版本。bias 未被剪枝,因此它将保持不变。

print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.1061, -0.0808,  0.1214],
          [-0.0475,  0.1144, -0.1554, -0.1009,  0.0610],
          [ 0.0423, -0.0510,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0675, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0500, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
          [ 0.1999,  0.0378,  0.0616, -0.1865, -0.1314],
          [-0.0666,  0.0313, -0.1760, -0.0862, -0.1197],
          [ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],


        [[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],
          [ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',
       requires_grad=True))]

由上述选择的剪枝技术生成的剪枝掩码将作为名为 weight_mask 的模块缓冲区保存(即在初始参数 name 后追加 "_mask")。

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]],


        [[[1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.],
          [0., 1., 0., 1., 1.]]],


        [[[1., 0., 0., 0., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [1., 1., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [0., 1., 1., 0., 1.],
          [1., 0., 0., 0., 1.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0'))]

为了使前向传播无需修改即可正常工作,weight 属性必须存在。在 torch.nn.utils.prune 中实现的剪枝技术通过将掩码与原始参数结合,计算出剪枝后的权重版本,并将其存储在 weight 属性中。需要注意的是,这不再是 module 的参数,而只是一个属性。

print(module.weight)
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0000,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0000,  0.0000,  0.0271,  0.0964],
          [-0.0282,  0.0000,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0000, -0.0000,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0000, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.0000,  0.0312,  0.1616,  0.0219, -0.0000],
          [ 0.0537, -0.0542,  0.0000,  0.1786,  0.0000],
          [-0.0000,  0.1155,  0.0000,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0000, -0.0000,  0.0000,  0.0576],
          [ 0.0828,  0.0000, -0.0035,  0.1565, -0.0000],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0000,  0.1925, -0.1651, -0.0000]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.0000, -0.1363,  0.1061, -0.0808,  0.0000],
          [-0.0475,  0.1144, -0.1554, -0.0000,  0.0610],
          [ 0.0000, -0.0000,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0000, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0000, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.0000, -0.0000, -0.0000],
          [ 0.1999,  0.0378,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0313, -0.1760, -0.0000, -0.1197],
          [ 0.0006, -0.0000, -0.0000, -0.0000, -0.1373]]],


        [[[-0.1167, -0.0000, -0.1579,  0.0000, -0.0397],
          [ 0.0000,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.0000,  0.0572, -0.0000],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.0000,  0.0000,  0.1822, -0.1586]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最终,在每次前向传播之前,使用 PyTorch 的 forward_pre_hooks 进行剪枝。具体来说,当 module 被剪枝时(如我们在此处所做的),它将会为每个被剪枝的相关参数获取一个 forward_pre_hook。在这个例子中,由于我们目前只剪枝了名为 weight 的原始参数,因此只会存在一个钩子。

print(module._forward_pre_hooks)
OrderedDict([(15, <torch.nn.utils.prune.RandomUnstructured object at 0x7feeaaf4e650>)])

为了完整性,我们现在也可以对 bias 进行修剪,以观察 module 的参数、缓冲区、钩子和属性如何变化。只是为了尝试另一种修剪技术,这里我们使用 L1 范数来修剪 bias 中最小的 3 个条目,这在 l1_unstructured 修剪函数中实现。

prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

我们现在期望命名参数包括 weight_orig(之前的)和 bias_orig。缓冲区将包括 weight_maskbias_mask。这两个张量的修剪版本将作为模块属性存在,并且模块现在将有两个 forward_pre_hooks

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.1061, -0.0808,  0.1214],
          [-0.0475,  0.1144, -0.1554, -0.1009,  0.0610],
          [ 0.0423, -0.0510,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0675, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0500, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
          [ 0.1999,  0.0378,  0.0616, -0.1865, -0.1314],
          [-0.0666,  0.0313, -0.1760, -0.0862, -0.1197],
          [ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],


        [[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],
          [ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]],


        [[[1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.],
          [0., 1., 0., 1., 1.]]],


        [[[1., 0., 0., 0., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [1., 1., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [0., 1., 1., 0., 1.],
          [1., 0., 0., 0., 1.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 0., 1., 1., 1.], device='cuda:0'))]
print(module.bias)
tensor([ 0.0000, -0.0000, -0.0000, -0.1497,  0.1822, -0.1468], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(15, <torch.nn.utils.prune.RandomUnstructured object at 0x7feeaaf4e650>), (16, <torch.nn.utils.prune.L1Unstructured object at 0x7feeaaf4e590>)])

迭代剪枝

同一个模块中的参数可以被多次修剪,多次修剪调用的效果等同于连续应用的各种掩码的组合。新掩码与旧掩码的组合由 PruningContainercompute_mask 方法处理。

例如,假设我们现在想要进一步修剪 module.weight,这次沿着张量的第 0 轴(第 0 轴对应于卷积层的输出通道,对于 conv1 来说维度为 6)进行基于通道 L2 范数的结构化修剪。这可以通过使用 ln_structured 函数来实现,其中 n=2dim=0

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000, -0.0000]]],


        [[[-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.1980, -0.0000, -0.0000,  0.0000,  0.0576],
          [ 0.0828,  0.0000, -0.0035,  0.1565, -0.0000],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0000,  0.1925, -0.1651, -0.0000]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.0000, -0.1363,  0.1061, -0.0808,  0.0000],
          [-0.0475,  0.1144, -0.1554, -0.0000,  0.0610],
          [ 0.0000, -0.0000,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0000, -0.0709, -0.1935]]],


        [[[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[-0.1167, -0.0000, -0.1579,  0.0000, -0.0397],
          [ 0.0000,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.0000,  0.0572, -0.0000],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.0000,  0.0000,  0.1822, -0.1586]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

相应的钩子现在将是 torch.nn.utils.prune.PruningContainer 类型,并将存储应用于 weight 参数的剪枝历史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
[<torch.nn.utils.prune.RandomUnstructured object at 0x7feeaaf4e650>, <torch.nn.utils.prune.LnStructured object at 0x7feeaaf4e230>]

序列化剪枝后的模型

所有相关的张量,包括掩码缓冲区和用于计算剪枝后张量的原始参数,都存储在模型的 state_dict 中,因此可以轻松地序列化并保存(如果需要)。

print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

移除剪枝重新参数化

要使剪枝永久生效,可以通过移除 weight_origweight_mask 的重新参数化,并移除 forward_pre_hook,我们可以使用 torch.nn.utils.prune 中的 remove 功能。需要注意的是,这并不会撤销剪枝操作,就好像剪枝从未发生过一样。它只是通过将剪枝后的 weight 参数重新分配给模型参数,使其永久生效。

在移除重新参数化之前:

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.1061, -0.0808,  0.1214],
          [-0.0475,  0.1144, -0.1554, -0.1009,  0.0610],
          [ 0.0423, -0.0510,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0675, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0500, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
          [ 0.1999,  0.0378,  0.0616, -0.1865, -0.1314],
          [-0.0666,  0.0313, -0.1760, -0.0862, -0.1197],
          [ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],


        [[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],
          [ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 0., 0., 0., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 0., 1., 1., 1.], device='cuda:0'))]
print(module.weight)
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000, -0.0000]]],


        [[[-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.1980, -0.0000, -0.0000,  0.0000,  0.0576],
          [ 0.0828,  0.0000, -0.0035,  0.1565, -0.0000],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0000,  0.1925, -0.1651, -0.0000]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.0000, -0.1363,  0.1061, -0.0808,  0.0000],
          [-0.0475,  0.1144, -0.1554, -0.0000,  0.0610],
          [ 0.0000, -0.0000,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0000, -0.0709, -0.1935]]],


        [[[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[-0.1167, -0.0000, -0.1579,  0.0000, -0.0397],
          [ 0.0000,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.0000,  0.0572, -0.0000],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.0000,  0.0000,  0.1822, -0.1586]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

在移除重新参数化之后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000, -0.0000]]],


        [[[-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.1980, -0.0000, -0.0000,  0.0000,  0.0576],
          [ 0.0828,  0.0000, -0.0035,  0.1565, -0.0000],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0000,  0.1925, -0.1651, -0.0000]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.0000, -0.1363,  0.1061, -0.0808,  0.0000],
          [-0.0475,  0.1144, -0.1554, -0.0000,  0.0610],
          [ 0.0000, -0.0000,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0000, -0.0709, -0.1935]]],


        [[[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[-0.1167, -0.0000, -0.1579,  0.0000, -0.0397],
          [ 0.0000,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.0000,  0.0572, -0.0000],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.0000,  0.0000,  0.1822, -0.1586]]]], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([0., 0., 0., 1., 1., 1.], device='cuda:0'))]

在模型中剪枝多个参数

通过指定所需的剪枝技术和参数,我们可以轻松地修剪网络中的多个张量,例如根据它们的类型进行修剪,正如我们将在本示例中看到的那样。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全局剪枝

到目前为止,我们只讨论了通常被称为“局部”剪枝的方法,即通过比较每个张量中的各个条目的统计信息(如权重大小、激活值、梯度等),逐一剪枝模型中的张量。然而,一种更为常见且可能更强大的技术是一次性剪枝整个模型,例如移除整个模型中最低的20%连接,而不是逐层移除每层中最低的20%连接。这可能会导致每层的剪枝比例不同。让我们看看如何使用torch.nn.utils.prune中的global_unstructured来实现这一点。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

现在我们可以检查每个剪枝参数中引入的稀疏性,每一层的稀疏性不会都等于20%。然而,全局稀疏性将(大约)为20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Sparsity in conv1.weight: 4.67%
Sparsity in conv2.weight: 13.92%
Sparsity in fc1.weight: 22.16%
Sparsity in fc2.weight: 12.10%
Sparsity in fc3.weight: 11.31%
Global sparsity: 20.00%

扩展 torch.nn.utils.prune 以支持自定义剪枝函数

要实现您自己的剪枝函数,可以通过子类化 BasePruningMethod 基类来扩展 nn.utils.prune 模块,就像其他所有剪枝方法所做的那样。基类为您实现了以下方法:__call__apply_maskapplypruneremove。除了一些特殊情况外,您不需要为新的剪枝技术重新实现这些方法。然而,您需要实现 __init__(构造函数)和 compute_mask(根据您的剪枝技术逻辑计算给定张量的掩码的指令)。此外,您还需要指定该剪枝技术实现的剪枝类型(支持的选项有 globalstructuredunstructured)。这是为了确定在迭代应用剪枝时如何组合掩码。换句话说,当剪枝一个已经剪枝过的参数时,当前的剪枝技术应作用于该参数的未剪枝部分。指定 PRUNING_TYPE 将使 PruningContainer(处理剪枝掩码的迭代应用)能够正确识别要剪枝的参数部分。

假设您想要实现一种剪枝技术,该技术会剪掉张量中每隔一个的条目(或者,如果张量之前已经被剪枝过,则剪掉张量中剩余的未剪枝部分)。这将属于 PRUNING_TYPE='unstructured',因为它作用于层中的单个连接,而不是整个单元/通道('structured'),或者跨不同参数('global')。

classFooBarPruningMethod(prune.BasePruningMethod):
"""Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    defcompute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

现在,要将此应用于 nn.Module 中的参数,您还需要提供一个简单的函数来实例化该方法并应用它。

deffoobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

让我们来试试吧!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
本页目录