Pruning Tutorial

Created On: Jul 22, 2019 | Last Updated: Nov 02, 2023 | Last Verified: Nov 05, 2024

Author: Michela Paganini

State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. On the contrary, biological neural networks are known to use efficient sparse connectivity. Identifying optimal techniques to compress models by reducing the number of parameters in them is important in order to reduce memory, battery, and hardware consumption without sacrificing accuracy. This in turn allows you to deploy lightweight models on device, and guarantee privacy with private on-device computation. On the research front, pruning is used to investigate the differences in learning dynamics of over-parametrized and under-parametrized networks, to study the role of lucky sparse subnetworks and initializations ("lottery tickets") as a destructive neural architecture search technique, and more.

In this tutorial, you will learn how to use ``torch.nn.utils.prune`` to sparsify your neural networks, and how to extend it to implement your own custom pruning technique.

Requirements

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

Create a model

In this tutorial, we use the `LeNet <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_ architecture from LeCun et al., 1998.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

Inspect a Module

Let's inspect the (unpruned) ``conv1`` layer in our LeNet model. For now, it contains two parameters, ``weight`` and ``bias``, and no buffers.

module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.0713,  0.0620, -0.0759,  0.0741,  0.1250],
          [-0.1836, -0.1390,  0.0297,  0.0674,  0.1307],
          [ 0.0881,  0.0229, -0.1275,  0.0220, -0.0697],
          [ 0.0231,  0.1802,  0.1602,  0.0758,  0.0954],
          [-0.1503,  0.0802, -0.0954, -0.0521, -0.0328]]],


        [[[-0.1326,  0.0053, -0.0103,  0.1243, -0.0702],
          [-0.0951,  0.0987,  0.0470, -0.1065,  0.0745],
          [ 0.1494,  0.1755, -0.1427, -0.0911,  0.1237],
          [ 0.1055,  0.0347, -0.0191,  0.0835,  0.0617],
          [-0.1599, -0.0306, -0.0636, -0.0443,  0.1766]]],


        [[[ 0.1243,  0.0394, -0.1804, -0.1733, -0.1859],
          [ 0.1313, -0.0756,  0.1669,  0.0347, -0.1517],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.1254, -0.1733,  0.0873],
          [ 0.1968,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0953, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.1279,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0348,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.1367,  0.0790],
          [-0.0463, -0.1023, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0313, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.1502],
          [-0.1971, -0.0557, -0.0128, -0.0448,  0.0844]]],


        [[[ 0.1840, -0.1198, -0.0209,  0.0126, -0.1531],
          [ 0.0700, -0.0692,  0.1178,  0.1409, -0.0063],
          [ 0.0746, -0.0518, -0.1475, -0.1488, -0.1001],
          [ 0.0157, -0.1287, -0.0178, -0.0382,  0.1183],
          [-0.0451,  0.1595, -0.0645, -0.1487,  0.1408]]]], device='cuda:0',
       requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0052, -0.1768, -0.1013, -0.1593,  0.1128,  0.1546], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[]

Pruning a Module

To prune a module (in this example, the ``conv1`` layer of our LeNet architecture), first select a pruning technique among those available in ``torch.nn.utils.prune`` (or `implement <#extending-torch-nn-utils-pruning-with-custom-pruning-functions>`_ your own by subclassing ``BasePruningMethod``). Then, specify the module and the name of the parameter to prune within that module. Finally, using the adequate keyword arguments required by the selected pruning technique, specify the pruning parameters.

In this example, we will prune at random 30% of the connections in the parameter named ``weight`` in the ``conv1`` layer. The module is passed as the first argument to the function; ``name`` identifies the parameter within that module using its string identifier; and ``amount`` indicates either the percentage of connections to prune (if it is a float between 0 and 1), or the absolute number of connections to prune (if it is a non-negative integer).

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
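To illustrate the other ``amount`` semantics, the same call with a non-negative integer prunes an exact number of connections. A minimal sketch on a throwaway layer (``tmp`` is just a scratch module introduced for this example):

tmp = nn.Conv2d(1, 6, 5)
prune.random_unstructured(tmp, name="weight", amount=5)  # prune exactly 5 entries
print(int((tmp.weight == 0).sum()))  # expected: 5 (barring pre-existing zeros)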

Pruning acts by removing ``weight`` from the parameters and replacing it with a new parameter called ``weight_orig`` (i.e. appending ``"_orig"`` to the initial parameter ``name``). ``weight_orig`` stores the unpruned version of the tensor. The ``bias`` was not pruned, so it will remain intact.

print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([-0.0052, -0.1768, -0.1013, -0.1593,  0.1128,  0.1546], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0713,  0.0620, -0.0759,  0.0741,  0.1250],
          [-0.1836, -0.1390,  0.0297,  0.0674,  0.1307],
          [ 0.0881,  0.0229, -0.1275,  0.0220, -0.0697],
          [ 0.0231,  0.1802,  0.1602,  0.0758,  0.0954],
          [-0.1503,  0.0802, -0.0954, -0.0521, -0.0328]]],


        [[[-0.1326,  0.0053, -0.0103,  0.1243, -0.0702],
          [-0.0951,  0.0987,  0.0470, -0.1065,  0.0745],
          [ 0.1494,  0.1755, -0.1427, -0.0911,  0.1237],
          [ 0.1055,  0.0347, -0.0191,  0.0835,  0.0617],
          [-0.1599, -0.0306, -0.0636, -0.0443,  0.1766]]],


        [[[ 0.1243,  0.0394, -0.1804, -0.1733, -0.1859],
          [ 0.1313, -0.0756,  0.1669,  0.0347, -0.1517],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.1254, -0.1733,  0.0873],
          [ 0.1968,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0953, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.1279,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0348,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.1367,  0.0790],
          [-0.0463, -0.1023, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0313, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.1502],
          [-0.1971, -0.0557, -0.0128, -0.0448,  0.0844]]],


        [[[ 0.1840, -0.1198, -0.0209,  0.0126, -0.1531],
          [ 0.0700, -0.0692,  0.1178,  0.1409, -0.0063],
          [ 0.0746, -0.0518, -0.1475, -0.1488, -0.1001],
          [ 0.0157, -0.1287, -0.0178, -0.0382,  0.1183],
          [-0.0451,  0.1595, -0.0645, -0.1487,  0.1408]]]], device='cuda:0',
       requires_grad=True))]

The pruning mask generated by the pruning technique selected above is saved as a module buffer named ``weight_mask`` (i.e. appending ``"_mask"`` to the initial parameter ``name``).

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 0., 0.],
          [0., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 1., 1., 0.]]],


        [[[0., 1., 1., 0., 0.],
          [1., 0., 0., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 1., 1.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 0., 0.],
          [0., 0., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.]]],


        [[[1., 0., 0., 1., 1.],
          [0., 1., 1., 0., 0.],
          [1., 0., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0'))]

For the forward pass to work without modification, the ``weight`` attribute needs to exist. The pruning techniques implemented in ``torch.nn.utils.prune`` compute the pruned version of the weight (by combining the mask with the original parameter) and store it in the attribute ``weight``. Note that this is no longer a parameter of the ``module``; it is now simply an attribute.

print(module.weight)
tensor([[[[ 0.0713,  0.0620, -0.0759,  0.0000,  0.0000],
          [-0.0000, -0.1390,  0.0297,  0.0674,  0.1307],
          [ 0.0881,  0.0000, -0.0000,  0.0220, -0.0697],
          [ 0.0231,  0.0000,  0.1602,  0.0758,  0.0954],
          [-0.1503,  0.0000, -0.0954, -0.0521, -0.0000]]],


        [[[-0.0000,  0.0053, -0.0103,  0.0000, -0.0000],
          [-0.0951,  0.0000,  0.0000, -0.1065,  0.0745],
          [ 0.0000,  0.1755, -0.1427, -0.0911,  0.1237],
          [ 0.1055,  0.0347, -0.0000,  0.0835,  0.0617],
          [-0.1599, -0.0000, -0.0000, -0.0443,  0.1766]]],


        [[[ 0.0000,  0.0394, -0.0000, -0.1733, -0.0000],
          [ 0.1313, -0.0756,  0.0000,  0.0347, -0.0000],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.0000, -0.1733,  0.0873],
          [ 0.0000,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0000, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.0000,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0000,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0000, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.0000],
          [-0.1971, -0.0557, -0.0000, -0.0448,  0.0000]]],


        [[[ 0.1840, -0.0000, -0.0000,  0.0126, -0.1531],
          [ 0.0000, -0.0692,  0.1178,  0.0000, -0.0000],
          [ 0.0746, -0.0000, -0.1475, -0.1488, -0.1001],
          [ 0.0157, -0.0000, -0.0178, -0.0382,  0.0000],
          [-0.0451,  0.0000, -0.0000, -0.1487,  0.1408]]]], device='cuda:0',
       grad_fn=<MulBackward0>)
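
Incidentally, we can verify this relationship directly; a minimal sanity check:

# The pruned `weight` attribute is simply the element-wise product of the
# original parameter and the binary mask.
assert torch.equal(module.weight, module.weight_orig * module.weight_mask)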

Finally, pruning is applied prior to each forward pass using PyTorch's ``forward_pre_hooks``. Specifically, when the ``module`` is pruned, as we have done here, it acquires a ``forward_pre_hook`` for each parameter associated with it that gets pruned. In this case, since we have so far only pruned the original parameter named ``weight``, only one hook will be present.

print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1f9d9bf940>)])

For completeness, we can now prune the ``bias`` too, to see how the parameters, buffers, hooks, and attributes of the ``module`` change. Just for the sake of trying out another pruning technique, here we prune the 3 smallest entries in the bias by L1 norm, as implemented in the ``l1_unstructured`` pruning function.

prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

We now expect the named parameters to include both ``weight_orig`` (from before) and ``bias_orig``. The buffers will include ``weight_mask`` and ``bias_mask``. The pruned versions of the two tensors will exist as module attributes, and the module will now have two ``forward_pre_hooks``.

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.0713,  0.0620, -0.0759,  0.0741,  0.1250],
          [-0.1836, -0.1390,  0.0297,  0.0674,  0.1307],
          [ 0.0881,  0.0229, -0.1275,  0.0220, -0.0697],
          [ 0.0231,  0.1802,  0.1602,  0.0758,  0.0954],
          [-0.1503,  0.0802, -0.0954, -0.0521, -0.0328]]],


        [[[-0.1326,  0.0053, -0.0103,  0.1243, -0.0702],
          [-0.0951,  0.0987,  0.0470, -0.1065,  0.0745],
          [ 0.1494,  0.1755, -0.1427, -0.0911,  0.1237],
          [ 0.1055,  0.0347, -0.0191,  0.0835,  0.0617],
          [-0.1599, -0.0306, -0.0636, -0.0443,  0.1766]]],


        [[[ 0.1243,  0.0394, -0.1804, -0.1733, -0.1859],
          [ 0.1313, -0.0756,  0.1669,  0.0347, -0.1517],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.1254, -0.1733,  0.0873],
          [ 0.1968,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0953, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.1279,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0348,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.1367,  0.0790],
          [-0.0463, -0.1023, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0313, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.1502],
          [-0.1971, -0.0557, -0.0128, -0.0448,  0.0844]]],


        [[[ 0.1840, -0.1198, -0.0209,  0.0126, -0.1531],
          [ 0.0700, -0.0692,  0.1178,  0.1409, -0.0063],
          [ 0.0746, -0.0518, -0.1475, -0.1488, -0.1001],
          [ 0.0157, -0.1287, -0.0178, -0.0382,  0.1183],
          [-0.0451,  0.1595, -0.0645, -0.1487,  0.1408]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0052, -0.1768, -0.1013, -0.1593,  0.1128,  0.1546], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[1., 1., 1., 0., 0.],
          [0., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 1., 1., 0.]]],


        [[[0., 1., 1., 0., 0.],
          [1., 0., 0., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 1., 1.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 0., 0.],
          [0., 0., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.]]],


        [[[1., 0., 0., 1., 1.],
          [0., 1., 1., 0., 0.],
          [1., 0., 1., 1., 1.],
          [1., 0., 1., 1., 0.],
          [1., 0., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 1., 0., 1., 0., 1.], device='cuda:0'))]
print(module.bias)
tensor([-0.0000, -0.1768, -0.0000, -0.1593,  0.0000,  0.1546], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1f9d9bf940>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1f9d9bcaf0>)])

Iterative Pruning

The same parameter in a module can be pruned multiple times, with the effect of the various pruning calls being equal to the combination of the various masks applied in series. The combination of a new mask with the old mask is handled by the ``PruningContainer``'s ``compute_mask`` method.

Say, for example, that we now want to further prune ``module.weight``, this time using structured pruning along the 0th axis of the tensor (the 0th axis corresponds to the output channels of the convolutional layer and has dimensionality 6 for ``conv1``), based on the channels' L2 norm. This can be achieved using the ``ln_structured`` function, with ``n=2`` and ``dim=0``.

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000,  0.0394, -0.0000, -0.1733, -0.0000],
          [ 0.1313, -0.0756,  0.0000,  0.0347, -0.0000],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.0000, -0.1733,  0.0873],
          [ 0.0000,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0000, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.0000,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0000,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0000, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.0000],
          [-0.1971, -0.0557, -0.0000, -0.0448,  0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

The corresponding hook will now be of type ``torch.nn.utils.prune.PruningContainer``, and will store the history of pruning applied to the ``weight`` parameter.

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
[<torch.nn.utils.prune.RandomUnstructured object at 0x7f1f9d9bf940>, <torch.nn.utils.prune.LnStructured object at 0x7f1f9d9bcb50>]
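
Because the container applies the masks in series, a later mask can only zero out additional entries; sparsity is monotonically non-decreasing across pruning calls. A quick check (a sketch):

mask = module.weight_mask
print(float((mask == 0).sum()) / mask.nelement())  # >= 0.3 from the first pruning call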

Serializing a pruned model

All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors, are stored in the model's ``state_dict`` and can therefore be easily serialized and saved, if needed.

print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
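
If you want to save the pruned model and restore it later without first making the pruning permanent, note that the re-parametrized keys (e.g. ``conv1.weight_orig``, ``conv1.weight_mask``) must also exist in the destination module before ``load_state_dict`` will accept them. A sketch, where ``"pruned_lenet.pt"`` is an arbitrary file name chosen for this example:

# Save the state (masks and original parameters included).
torch.save(model.state_dict(), "pruned_lenet.pt")

# To reload into a fresh LeNet, first recreate the re-parametrization on the
# pruned parameters (an all-ones identity mask) so the keys line up.
restored = LeNet().to(device=device)
prune.identity(restored.conv1, "weight")
prune.identity(restored.conv1, "bias")
restored.load_state_dict(torch.load("pruned_lenet.pt"))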

Remove pruning re-parametrization

To make the pruning permanent, i.e. to remove the re-parametrization in terms of ``weight_orig`` and ``weight_mask`` and to remove the ``forward_pre_hook``, we can use the ``remove`` functionality from ``torch.nn.utils.prune``. Note that this doesn't undo the pruning, as if it never happened; it simply makes it permanent by reassigning the parameter ``weight`` to the model parameters, in its pruned version.

Prior to removing the re-parametrization:

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.0713,  0.0620, -0.0759,  0.0741,  0.1250],
          [-0.1836, -0.1390,  0.0297,  0.0674,  0.1307],
          [ 0.0881,  0.0229, -0.1275,  0.0220, -0.0697],
          [ 0.0231,  0.1802,  0.1602,  0.0758,  0.0954],
          [-0.1503,  0.0802, -0.0954, -0.0521, -0.0328]]],


        [[[-0.1326,  0.0053, -0.0103,  0.1243, -0.0702],
          [-0.0951,  0.0987,  0.0470, -0.1065,  0.0745],
          [ 0.1494,  0.1755, -0.1427, -0.0911,  0.1237],
          [ 0.1055,  0.0347, -0.0191,  0.0835,  0.0617],
          [-0.1599, -0.0306, -0.0636, -0.0443,  0.1766]]],


        [[[ 0.1243,  0.0394, -0.1804, -0.1733, -0.1859],
          [ 0.1313, -0.0756,  0.1669,  0.0347, -0.1517],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.1254, -0.1733,  0.0873],
          [ 0.1968,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0953, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.1279,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0348,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.1367,  0.0790],
          [-0.0463, -0.1023, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0313, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.1502],
          [-0.1971, -0.0557, -0.0128, -0.0448,  0.0844]]],


        [[[ 0.1840, -0.1198, -0.0209,  0.0126, -0.1531],
          [ 0.0700, -0.0692,  0.1178,  0.1409, -0.0063],
          [ 0.0746, -0.0518, -0.1475, -0.1488, -0.1001],
          [ 0.0157, -0.1287, -0.0178, -0.0382,  0.1183],
          [-0.0451,  0.1595, -0.0645, -0.1487,  0.1408]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0052, -0.1768, -0.1013, -0.1593,  0.1128,  0.1546], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 0., 0.],
          [0., 0., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 1., 0., 1., 0., 1.], device='cuda:0'))]
print(module.weight)
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000,  0.0394, -0.0000, -0.1733, -0.0000],
          [ 0.1313, -0.0756,  0.0000,  0.0347, -0.0000],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.0000, -0.1733,  0.0873],
          [ 0.0000,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0000, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.0000,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0000,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0000, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.0000],
          [-0.1971, -0.0557, -0.0000, -0.0448,  0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

After removing the re-parametrization:

prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([-0.0052, -0.1768, -0.1013, -0.1593,  0.1128,  0.1546], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000,  0.0394, -0.0000, -0.1733, -0.0000],
          [ 0.1313, -0.0756,  0.0000,  0.0347, -0.0000],
          [ 0.1714, -0.1959, -0.1689,  0.0673,  0.0251],
          [-0.1600, -0.0607,  0.0000, -0.1733,  0.0873],
          [ 0.0000,  0.1733,  0.1670,  0.1426, -0.1091]]],


        [[[ 0.0248, -0.1721,  0.1796, -0.0456, -0.0657],
          [-0.0000, -0.1104,  0.0952,  0.0531, -0.0051],
          [-0.0968,  0.1052, -0.0977,  0.0518,  0.1492],
          [ 0.0673,  0.1428, -0.0000,  0.0764, -0.0350],
          [-0.1370, -0.0935,  0.1949,  0.0000,  0.0239]]],


        [[[-0.1931,  0.0271,  0.0981, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0496, -0.0704, -0.1147],
          [-0.1976, -0.0000, -0.0585,  0.1317, -0.1063],
          [-0.0924, -0.0234, -0.1526,  0.0646, -0.0000],
          [-0.1971, -0.0557, -0.0000, -0.0448,  0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([0., 1., 0., 1., 0., 1.], device='cuda:0'))]
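
The ``bias`` re-parametrization is still in place at this point; it can be made permanent in the same way:

prune.remove(module, 'bias')
print(list(module.named_buffers()))  # now expected to be empty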

Pruning multiple parameters in a model

By specifying the desired pruning technique and parameters, we can easily prune multiple tensors in a network, perhaps according to their type, as we will see in this example.

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
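
Once you are done pruning, the same loop pattern can be used to make every mask permanent; a sketch, under the assumption that no further pruning is planned:

# Fold each mask into its parameter, turning `new_model` back into a plain
# nn.Module with no pruning hooks or buffers.
for name, module in new_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.remove(module, 'weight')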

Global pruning

So far, we have only looked at what is usually referred to as "local" pruning, i.e. the practice of pruning tensors in a model one by one, by comparing the statistics (weight magnitude, activation, gradient, etc.) of each entry exclusively to the other entries in that tensor. However, a common and perhaps more powerful technique is to prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model, instead of removing the lowest 20% of connections in each layer. This is likely to result in different pruning percentages per layer. Let's see how to do that using ``global_unstructured`` from ``torch.nn.utils.prune``.

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

Now we can check the sparsity induced in every pruned parameter, which will not be equal to 20% in each layer. However, the global sparsity will be (approximately) 20%.

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Sparsity in conv1.weight: 6.00%
Sparsity in conv2.weight: 13.67%
Sparsity in fc1.weight: 22.06%
Sparsity in fc2.weight: 12.77%
Sparsity in fc3.weight: 9.52%
Global sparsity: 20.00%
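
The per-layer bookkeeping above can be wrapped in a small helper; a sketch of one way to write it:

def sparsity_report(model):
    """Print per-layer and global sparsity for every pruned weight."""
    total_zeros, total_elems = 0, 0
    for mod_name, mod in model.named_modules():
        if hasattr(mod, "weight_mask"):  # only modules pruned on `weight`
            zeros = int((mod.weight == 0).sum())
            total_zeros += zeros
            total_elems += mod.weight.nelement()
            print("Sparsity in {}.weight: {:.2f}%".format(
                mod_name, 100. * zeros / mod.weight.nelement()))
    if total_elems:
        print("Global sparsity: {:.2f}%".format(100. * total_zeros / total_elems))

sparsity_report(model)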

Extending ``torch.nn.utils.prune`` with custom pruning functions

To implement your own pruning function, you can extend the ``nn.utils.prune`` module by subclassing the ``BasePruningMethod`` base class, the same way all other pruning methods do. The base class implements the following methods for you: ``__call__``, ``apply_mask``, ``apply``, ``prune``, and ``remove``. Beyond some special cases, you shouldn't have to reimplement these methods for your new pruning technique. You will, however, have to implement ``__init__`` (the constructor) and ``compute_mask`` (the instructions on how to compute the mask for the given tensor according to the logic of your pruning technique). In addition, you will have to specify which type of pruning this technique implements (supported options are ``global``, ``structured``, and ``unstructured``). This is needed to determine how to combine masks in the case in which pruning is applied iteratively. In other words, when pruning a pre-pruned parameter, the current pruning technique is expected to act on the unpruned portion of the parameter. Specifying the ``PRUNING_TYPE`` will enable the ``PruningContainer`` (which handles the iterative application of pruning masks) to correctly identify the slice of the parameter to prune.

Let's assume, for example, that you want to implement a pruning technique that prunes every other entry in a tensor (or, if the tensor has previously been pruned, in the remaining unpruned portion of the tensor). This will be of ``PRUNING_TYPE='unstructured'`` because it acts on individual connections in a layer and not on entire units/channels (``'structured'``), or across different parameters (``'global'``).

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        # `default_mask` reflects any pruning previously applied to this
        # parameter; clone it and zero out every other (flattened) entry.
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

Now, to apply this to a parameter in an ``nn.Module``, you should also provide a simple function that instantiates the method and applies it.

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensor.
    Modifies module in place (and also returns the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

Let's try it out!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
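
Because ``compute_mask`` receives the ``default_mask`` left over from any earlier pruning, the custom method composes with the built-in ones. A quick sketch:

# Prune an already-pruned parameter: `PruningContainer` hands our method the
# still-unpruned portion, and we zero out every other entry on top of it.
m = nn.Linear(3, 4)
prune.l1_unstructured(m, name='bias', amount=2)
foobar_unstructured(m, name='bias')
print(m.bias_mask)  # exactly one of the original 4 entries survives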

Total running time of the script: ( 0 minutes 0.986 seconds)
