李宏毅機器學習課程筆記-13.6模型壓縮代碼實戰

本文爲作者學習李宏毅機器學習課程時參照樣例完成homework7的記錄。

全部課程PPT、數據和代碼下載鏈接:

鏈接:https://pan.baidu.com/s/1n_N7aoaNxxwqO03EmV5Bjg 提取碼:tpmc

代碼倉庫:https://github.com/chouxianyu/LHY_ML2020_Codes

任務描述

通過Architecture Design、Knowledge Distillation、Network Pruning和Weight Quantization這4種模型壓縮策略,用一個非常小的model完成homework3中食物圖片分類的任務。

1.Architecture Design

MobileNet提出了Depthwise & Pointwise Convolution。我們在這裏實現MobileNet v1這個比較小的network,後續使用Knowledge Distillation策略訓練它,然後對它進行剪枝和量化。

2.Knowledge Distillation

將ResNet18作爲Teacher Net(使用torchvision中的ResNet18,僅將num_classes改成11,加載助教訓練好的Accuracy約爲88.4%的參數),將上一步(1.Architecture Design)設計的小model作爲Student Net,使用Knowledge_Distillation策略訓練Student Net。

Loss計算方法爲\(Loss = \alpha T^2 \times KL(\frac{\text{Teacher's Logits}}{T} || \frac{\text{Student's Logits}}{T}) + (1-\alpha)(\text{Original Loss})\),關於爲什麼要對student進行logsoftmax可見https://github.com/peterliht/knowledge-distillation-pytorch/issues/2

論文《Distilling the Knowledge in a Neural Network》:https://arxiv.org/abs/1503.02531

3.Network Pruning

對上一步(2.Knowledge_Distillation)訓練好的Student Net做剪枝。

根據論文《Learning Efficient Convolutional Networks through Network Slimming》,論文鏈接:https://arxiv.org/abs/1708.06519
BatchNorm層中的gamma值和一些特定卷積核(或者全連接層的一個神經元)相關聯,因此可以使用BatchNorm層中的gamma值判斷相關通道的重要性。

Student Net中CNN部分有幾個結構相同的Sequential,其結構、權重名稱、實現代碼、權重形狀如下表所示。

# name meaning code weight shape
0 cnn.{i}.0 Depthwise Convolution nn.Conv2d(x, x, 3, 1, 1, group=x) (x, 1, 3, 3)
1 cnn.{i}.1 Batch Normalization nn.BatchNorm2d(x) (x)
2 ReLU6 nn.ReLU6
3 cnn.{i}.3 Pointwise Convolution nn.Conv2d(x, y, 1), (y, x, 1, 1)
4 MaxPooling nn.MaxPool2d(2, 2, 0)

獨立剪枝prune_count次,每次剪枝的剪枝率按prune_rate逐漸增大,剪枝後微調finetune_epochs個epoch。

4.Weight Quantization

對第二步(2.Knowledge_Distillation)訓練好的Student Net做量化(用更少的bit表示一個value)。

torch預設的FloatTensor是32bit,而FloatTensor最低可以是16bit。

如何將32bit轉成8bit的int呢?對每個weight進行min-max normalization,然後乘以\(2^8-1\)再四捨五入成整數,這樣就可以轉成uint8了。

數據集描述

數據集爲homework3中食物圖片分類數據集。

11個圖片類別,訓練集中有9866張圖片,驗證集中有3430張圖片,測試集中有3347張圖片。

訓練集和驗證集中圖片命名格式爲類別_編號.jpg,編號不重要。

代碼

https://github.com/chouxianyu/LHY_ML2020_Codes/tree/master/hw7_NetworkCompression


Github(github.com):@chouxianyu

Github Pages(github.io):@臭鹹魚

知乎(zhihu.com):@臭鹹魚

博客園(cnblogs.com):@臭鹹魚

B站(bilibili.com):@絕版臭鹹魚

微信公衆號:@臭鹹魚

轉載請註明出處,歡迎討論和交流!


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章