基于生成对抗的结构剪枝——Generative Adversarial Learning

"Towards Optimal Structured CNN Pruning via Generative Adversarial Learning" 这篇文章提出了非常新颖的结构剪枝策略:基于生成对抗的思路,将剪枝网络设置为生成器(Generator),其输出特征作为Fake,并设置Soft Mask门控一些异质结构的输出(如通道、分支、网络层或模块等);将预训练模型设置为Baseline,Baseline的输出特征作为Real;再引入判别器(Discriminator)与正则化约束,一方面对齐生成器与Baseline的输出,另一方面驱使生成器中的Soft Mask稀疏化(mask value介于0到1之间),最终达到低精度损失的结构剪枝的目的。基于GAL(Generative Adversarial Learning)的剪枝策略总体如下图所示:

基于GAL的剪枝策略能够克服现有结构剪枝技术的不足,不足之处具体表现在:1)相对耗时的多阶段优化,迭代执行剪枝与fine-tuning;2)通常采用hard pruning mask,不够松弛、较难优化学习;3)训练或正则化过程依赖于样本标注。针对这些不足,基于GAL的剪枝策略,首先通过Baseline与Discriminator的辅助作用,能够在对抗学习过程中避免样本标注的使用;其次,Soft Pruning Mask的使用,使得正则化过程变得更加松弛、更容易学习收敛;另外,对抗训练与正则化过程是端到端的、非逐层实施的,并且能够自动完成最优网络结构探索、以及类似于知识蒸馏的特征迁移(Baseline -> Generator)。基于GAL的剪枝策略涉及的符号标记如下,fb(x)与fg(x)分别表示Baseline与Generator输出的特征矢量(非Softmax层):

通过Soft Mask(标记为m)的稀疏化,可以剪除包括通道、分支或Block等在内的基本结构。为了确保剪枝之后,剪枝模型仍能获得与Baseline相接近的推理精度,基于GAL的剪枝策略首先对Soft Mask施加L1正则化;其次引入判别器(Discriminator),与剪枝模型(Generator)构成了生成对抗学习,在对抗学习过程中将Baseline输出的特征矢量作为监督信息,用以对齐Baseline与剪枝模型的特征输出。在对抗学习与正则化过程中,Baseline的参数固定、不需要更新,而剪枝模型参数WG、Soft Mask以及判别器参数WD需要更新,具体的优化问题如下:

上式中,表示判别器损失,用来引导判别器提升鉴别能力,Baseline的输出表示Real,而剪枝模型(Generator)的输出表示Fake,当二者输出真假难辨时,达到对齐到输出特征的目的:

式(1)中数据损失用来进一步对齐Baseline与Generator的输出特征,具体表示为Baseline与Generator输出特征之间的MSE损失:

式(1)中正则化损失主要分为三部分,分别表示对WGmWD的正则化约束:

上式中R(WG)表示一般的weight decay,且通常是L2正则化;R(m)表示对Soft Mask的L1正则化;R(WD)表示对判别器的正则化约束,用以防止判别器主导训练学习,并且主要采用对抗正则化,促进判别器与生成器之间的对抗竞争:

如果直接采用SGD求解式(1)的优化问题,Soft Mask较难稀疏化(零值较难获得)。此时通常需要设置一个阈值,并将低于阈值的Mask Value或Scaling Factor置零,达到剪枝的目的,然而剪枝网络的推理精度会明显低于Baseline。为解决该问题,文章引入FISTA方法用以求解式(1)的优化问题,具体如下(i=j=1):

优化策略主要包含两个交替执行的阶段:1)第一个阶段固定Gm,通过对抗训练更新判别器D,损失函数包含对抗损失与对抗正则项;2)第二阶段固定D,更新生成器G与Soft Mask,损失函数包含对抗损失中的fg相关项、fb与fg的MSE损失以及Gm的正则项。最终,完成Soft Mask的稀疏化之后,便可以按照门控方式,完成channel、branch或block的规整剪枝。

实验结果具体见文章的实验部分,值得注意的是:1)对判别器(Discriminator)施加正则化约束时,对抗正则化(Adversarial Regularization)相比于L1、L2正则化,能够起到更好的正则化效果,即达到更高稀疏度的同时,网络的推理精度也更高;2)相比于不使用GAN的特征迁移学习,GAL能够起到更好的监督效果,并且GAL是label-free的,能够更好地激励Generator输出与Baseline相接近的特征。

Paper地址:https://arxiv.org/abs/1903.09291

GitHub地址(PyTorch):https://github.com/ShaohuiLin/GAL

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章