由於BERT參數衆多,模型龐大,訓練與推理速度較慢,在一些實時性要求較高應用場景無法滿足需求,最近開始探索BERT輕量化部署
BERT輕量化的方式:
- 低精度量化。在模型訓練和推理中使用低精度(FP16甚至INT8、二值網絡)表示取代原有精度(FP32)表示。
- 模型裁剪和剪枝。減少模型層數和參數規模。
- 模型蒸餾。通過知識蒸餾方法[22]基於原始BERT模型蒸餾出符合上線要求的小模型。
本文主要分享下BERT的剪枝實踐,代碼來源於Rasa,對剪枝部分代碼進行了剝離和修改,然後進行試驗,模型是變小了,但是推理時間反而增加(囧),Rasa的代碼還在研究中,先放一部分試驗結果。
剪枝的方式兩種:
- neuron pruning
- weigth pruning
實驗結果:
結論就是剪枝完,並沒啥用??繼續看下代碼吧。哪位嘗試過rasa-bert剪枝代碼的同學,麻煩指導下。
部分試驗過程:
Variable: bert/encoder/layer_9/intermediate/dense/mask:0
Shape: (768, 3072)
Element sparsity: 50.0%
Column sparsity: 0.0% (3072/3072)
Variable: bert/encoder/layer_9/output/dense/mask:0
Shape: (3072, 768)
Element sparsity: 50.0%
Column sparsity: 0.0% (768/768)
Variable: bert/pooler/dense/mask:0
Shape: (768, 768)
Element sparsity: 50.0%
Column sparsity: 0.0% (768/768)
###########################################################################
Overall:
Element sparsity: 50.0%
Column sparsity: 0.0%
Variable: bert/encoder/layer_9/attention/self/query/mask:0
Shape: (768, 768)
Element sparsity: 63.8%
Column sparsity: 63.8% (278/768)
Variable: bert/encoder/layer_9/attention/self/value/mask:0
Shape: (768, 768)
Element sparsity: 29.3%
Column sparsity: 29.3% (543/768)
Variable: bert/encoder/layer_9/intermediate/dense/mask:0
Shape: (768, 3072)
Element sparsity: 76.6%
Column sparsity: 76.6% (720/3072)
Variable: bert/encoder/layer_9/output/dense/mask:0
Shape: (3072, 768)
Element sparsity: 16.3%
Column sparsity: 16.3% (643/768)
Variable: bert/pooler/dense/mask:0
Shape: (768, 768)
Element sparsity: 4.7%
Column sparsity: 4.7% (732/768)
###########################################################################
Overall:
Element sparsity: 41.1%
Column sparsity: 50.0%
補充下rasa的代碼位置:https://github.com/RasaHQ/rasa/tree/nlu_lstm-compressing_transformers/rasa/nlu/classifiers
文章: https://blog.rasa.com/compressing-bert-for-faster-prediction-2/
抽離出來後,需要對代碼進行修改,才能跑通。