MetaPruning: 基于元学习的自动化神经网络通道剪枝

本文出自论文MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning,提出来一个最新的元学习方法,对非常深的神经网络进行自动化通道剪枝。

本文提出来一个最新的元学习方法,对非常深的神经网络进行自动化通道剪枝。首先训练出一个PruningNet,对于给定目标网络的任何剪枝结构都可以生成权重参数。我们使用一个简单的随机结构采样方法来训练PruningNet,然后应用一个进化过程来搜索性能好的剪枝网络。这个搜索方法是非常高效的,因为权重直接通过训练好的PruningNet生成,并不需要在搜索时间中进行任何的微调。只需要为目标网络训练处一个简单的PruningNet,我们可以在不同的人工约束下搜索不同的剪枝网络。与当前最先进的剪枝方式相比,MetaPruning在MobileNet V1/V2和ResNet上有着最好的性能表现。



一、简介

  1. 通道剪枝作为一种神经网络的压缩方法被广泛的实现和应用,它通常包含三个阶段:训练一个大的超参数化网络,修剪次重要的权重或者通道,最后通过微调或者重训练剪枝网络来获得最终的剪枝网络。第二个阶段通常执行迭代式的逐层剪枝,然后快速微调或者权重重建来重获得精度。AutoML中利用自动寻找最优结构的特点,基于一个反馈循环或者强化学习,脱离了人工设计的局限并弥补了剪枝算法依赖数据的不足。
  2. 最近的研究表明通道剪枝的本质是找到好的剪枝结构-逐层的通道数量。由于迭代式寻找最优结构的计算代价很高,因此提出来训练PruningNet,来对于所有候选的剪枝网络架构,可以生成权重参数,这样我们能通过在验证数据集上评估其精度来搜索到性能好的架构。为了训练PruningNet,我们使用一个随机结构采样方法,使用相应的网络编码向量来生成剪枝网络的权重,即每一层的通道数量。通过随机输入不同的网络编码向量,PruningNet逐步学习生成不同剪枝结构的权重参数。在训练过程结束后,我们通过一个进化搜索方法来搜索到性能好的剪枝网络,可以灵活地结合到不同的约束例如计算浮点数或者硬件延迟。另外,通过决定每一层或每个阶段的通道,可以直接搜索到最好的剪枝网络,因此可以在shortcut结构中修剪通道。这种神经网络压缩方法被称作MetaPruning。
  3. MetaPruning的两个阶段:(1)训练一个PruningNet:在每次迭代过程中,一个网络编码向量(每层的通道数量)被随机生成,剪枝网络也相应地被构建出来,PruningNet将网络编码向量作为输入,来生成剪枝网络的权重参数;(2)搜索最佳剪枝网络:通过不同的网络编码向量构建了许多剪枝网络,并利用剪枝网络预测的权重对验证集的优劣进行了评估,在搜索时间内无微调或重训练过程。
    MetaPruning两阶段
  4. 将此方法应用于MobieNets和ResNet上,在相同的浮点数下,我们的精度比MobileNet V1高2.2%到6.6%,比MobileNet V2高0.7%到3.7%,比ResNet-50高0.6%到1.4%。在相同的延迟下,比MobileNet V1高2.1%到9.0%,比MobileNet V2高1.2%到9.9%。
  5. 本文的主要贡献点:(1)提出一个元学习方法MetaPruning来用于通道剪枝,其中心思想是学习一个元网络来生成不同剪枝结构的权重参数;(2)节省了超参数调优中的人力过程,允许使用所需要的度量标准来直接优化;(3)可以很容易地在搜索所需结构时实施约束,而不需要手动调整强化学习的参数;(4)可以不费力地修剪像ResNet结构这样的short-cuts的通道。

二、相关工作

  1. Pruning:网络剪枝对于深度网络的冗余度去除是一个普遍的方法。在权重修剪过程中,通常会剪去单个权重来压缩模型大小,但同时会导致非结构化的稀疏过滤器。传统的通道剪枝方法是根据每个通道的重要程度,以迭代方式修剪通道,或者添加一个数据驱动的稀疏度。
  2. AutoML:该方法将多设备上的实时推理延迟考虑在内,通过强化学习或者一个自动化的反馈循环在一个网络的不同层上迭代式修剪通道。与先前的AutoML剪枝方法相比,MetaPruning方法在精度满足约束条件方面具有较高的灵活性,并具有对short-cut中的通道进行修剪的能力。
  3. Meta Learning:它指代着学习观察不同的机器学习方法如何在不同的学习任务上执行。在本文中我们使用meta learning来进行权重预测,权重预测表示一个神经网络的权重被另一个神经网络所预测,而不是直接学习得到。
  4. Neural Architecture Search:使用强化学习、遗传算法或者基于梯度的方法找到最优的网络结构和超参数。通过与drop-path联合训练多项选择,它可以在训练过的网络中搜索到最高精度的路径。调整通道宽度也包含在一些神经架构搜索方法中。我们所提出的针对通道剪枝的MetaPruning方法能够通过训练PruningNet进行权重预测来解决这一连续的通道剪枝挑战问题。

三、方法

  1. 我们将通道剪枝问题用公式表示为:(c1,c2,...cl)=argminc1,c2,...clL(A(c1,c2,...cl;w))C<constraint(c_1,c_2,...c_l)^*=\mathop{\arg\min}\limits_{c_1,c_2,...c_l}{L}(A(c_1,c_2,...c_l;w))\quad C<constraint, 其中A是剪枝前的网络,我们尝试找到剪枝网络的通道宽度(从第一层到第L层),在权重被训练后有着最小的损失,同时C满足所规定的的约束(FLOPs或者延迟)。为此,我们提出构建一个PruningNet,一种元网络,可以通过在验证集上的评估快速获得所有可能剪枝网络结构的优劣度。然后我们可以应用任何搜索方法(比如进化算法)来搜索到最佳的剪枝网络。
  2. PruningNet训练:PruningNet是一个元网络,将一个网络编码向量(c1,c2,...clc_1,c_2,...c_l)作为输入,然后输出剪枝网络的权重,可表示为:W=PruningNet(c1,c2,...cl).W=PruningNet(c_1,c_2,...c_l). 一个PruningNet block由两个全连接层组成,在前向传递过程中,PruningNet将网络编码向量作为输入,然后生成权重矩阵。与此同时,一个剪枝网络被构造出,其每一层的输出通道宽度等同于网络编码向量中的元素。生成的权重矩阵被裁剪来匹配剪枝网络中输入输出通道的数量。在后向传递过程中,并没有更新剪枝网络的权重,而是计算PruningNet里权重的梯度。为了训练PruningNet,我们提出了随机结构采样,在训练阶段的每次迭代过程中,网络编码向量被随机生成来选择每层的通道数量。有着不同的网络编码,不同的剪枝网络被构建出来,相应的权重由PruningNet来提供。通过使用不同的编码向量随机训练,PruningNet学习预测各种不同剪枝网络的合理权重。

PruningNet随机训练方法网络架构以及reshape操作
5. 剪枝网络搜索:在PruningNet训练完后,我们可以通过输入网络编码到PruningNet中,生成相应的权重和在验证集上进行评估工作,来获取每个可能剪枝网络的精度。由于网络编码向量数量巨大的问题,为了在约束条件下找到高精度的剪枝网络,我们使用一个进化搜索,可以很容易地合并软硬性约束。每个剪枝网络被编码成一个包含每层通道数量的向量,被命名为剪枝网络的基因。我们首先随机选择大量的基因,通过做评估来获得相应剪枝网络的精度。然后前K个最高精度的基因被挑选出来,使用交叉和变异方法生成新的基因。变异操作通过随机改变基因中的元素比例来执行,交叉操作通过随机重组两个亲本基因的基因来产生后代。通过迭代进行这个过程,我们可以获得满足约束条件的基因,同时得到最高精度。
进化搜索算法

四、实验结果

  1. MetaPruning on MobileNets and ResNet:对于没有short-cut结构的网络MobileNet V1,我们裁剪原始权重矩阵的左上边,来匹配输入和输出通道。在MobileNet V2中,每个阶段都从匹配两个阶段之间的维度瓶颈块开始。为了修剪包含shortcut的结构,我们生成两个网络编码向量,一个对总体阶段的输出通道进行编码来匹配shortcut里的通道,另一个对每个block的中间通道进行编码。在PruningNet中,我们首先将网络编码向量解码为每个块的输入输出和中间通道压缩比,然后我们生成那个block块中的相应权重矩阵。ResNet和MobileNet V2的构建过程相同。在这里插入图片描述
  2. FLOPs约束下的剪枝效果比较:使用MetaPruning学习得到的剪枝方法,与0.25x下的MobileNet V1相比我们获得了6.6%的精度提升,同样与MobileNet V2和ResNet相比均获得了很好的提升。在与最先进的AutoML剪枝方法相比中,MetaPruning获得了较好的效果,它还消除了人工调整强化学习超参数的劣势。
    FLOPs约束下的性能比较
  3. 延迟约束下的剪枝效果比较:在合理的假设下,每一层的执行时间是独立的,我们可以通过将网络中所有层的运行时间相加来得到网络延迟。通过估计在目标设备上执行不同输入和输出通道宽度的卷积层的延迟,我们首先构建一个look-up表。然后我们可以从这个look-up表中计算得到构建网络的延迟。在相同延迟下利用MetaPruning得到的剪枝网络可以获得显著的更高精度。延迟约束下的性能比较
  4. 剪枝网络可视化:(1)当向下采样以步长为2的深度卷积方式进行时,需要使用更多的通道数量来携带信息,因此MetaPruning自动学会在下采样过程中保存更多的通道。(2)MetaPruning方法自动学会在靠后的阶段中修剪较少的shortcut通道数。可视化效果
  5. 权重预测的效果:元学习中的权重预测机制对不同剪枝结构的权重进行了有效的去相关处理,从而使PruningNet获得更高的精度。权重预测效果展示

五、结论

本文我们提出了MetaPruning用作通道剪枝,其具有以下优点:(1)与统一剪枝基线以及最先进的通道剪枝方法(传统的和AutoML)相比,其具有更高的精度;(2)在不引入额外超参数的情况下,它可以灵活地针对不同的约束条件进行优化;(3)像ResNet这样的架构可以有效地被处理;(4)整个过程是非常高效的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章