###好好好好#####模型蒸餾(Distil)及mnist實踐

結論:蒸餾是個好方法。

模型壓縮/蒸餾在論文《Model Compression》及《Distilling the Knowledge in a Neural Network》提及,下面介紹後者及使用keras測試mnist數據集。

蒸餾:使用小模型模擬大模型的泛性。

通常,我們訓練mnist時,target是分類標籤,在蒸餾模型時,使用的是教師模型的輸出概率分佈作爲“soft target”。也即損失爲學生網絡與教師網絡輸出的交叉熵(這裏採用DistilBert論文中的策略,此論文不同)。

當訓練好教師網絡後,我們可以不再需要分類標籤,只需要比較2個網絡的輸出概率分佈。當然可以在損失裏再加上學生網絡的分類損失,論文也提到可以進一步優化。

如圖,將softmax公式稍微變換一下,目的是使得輸出更小,softmax後就更爲平滑。

 

 

 論文的損失定義

 

 

本文代碼使用的損失爲p和q的交叉熵

代碼測試部分

1,教師網絡,測試精度99.46%,已經相當好了,可訓練參數858,618。

複製代碼

# 教師網絡
inputs=Input((28,28,1))
x=Conv2D(64,3)(inputs)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Conv2D(64,3,strides=2)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Conv2D(128,5)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Conv2D(128,5)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Flatten()(x)
x=Dense(100)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Dropout(0.3)(x)
x=Dense(10,activation='softmax')(x)
model=Model(inputs,x)
model.compile(optimizer=optimizers.SGD(momentum=0.8,nesterov=True),loss=categorical_crossentropy,metrics=['accuracy'])
model.summary()
model.fit(X_train,y_train,batch_size=128,epochs=30,validation_split=0.2,verbose=2)
# 重新編譯後,完整數據集訓練18輪,原始16輪後開始過擬合,訓練集變大後不易過擬合,這裏增加2輪
model.fit(X_train,y_train,batch_size=128,epochs=18,verbose=2)
model.evaluate(X_test,y_test)# 99.46%

複製代碼

2,學生網絡,測試精度99.24%,可訓練參數164,650,不到原來的1/5。

複製代碼

# 定義溫度
tempetature=3
# 學生網絡
inputs=Input((28,28,1))
x=Conv2D(16,3)(inputs)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Conv2D(16,3)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Conv2D(32,5)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Conv2D(32,5,strides=2)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Flatten()(x)
x=Dense(60)(x)
x=BatchNormalization(center=True,scale=False)(x)
x=Activation('relu')(x)
x=Dropout(0.3)(x)
x=Dense(10,activation='softmax')(x)
x=Lambda(lambda t:t/tempetature)(x)# softmax後除以溫度,使得更平滑
student=Model(inputs,x)
student.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),loss=categorical_crossentropy,metrics=['accuracy'])
# 使用老師和學生概率分佈結果的軟交叉熵,即除以溫度後的交叉熵
student.fit(X_train,model.predict(X_train)/tempetature,batch_size=128,epochs=30,verbose=2)

複製代碼

最後測試一下

student.evaluate(X_test,y_test/tempetature)# 99.24%

3,繼續減少參數,去除Dropout和BN,前期卷積使用步長,精度98.80%。參數72,334,大約原來的1/12。

複製代碼

# 定義溫度
tempetature=3
# 學生網絡
inputs=Input((28,28,1))
x=Conv2D(16,3,activation='relu')(inputs)
# x=BatchNormalization(center=True,scale=False)(x)
# x=Activation('relu')(x)
x=Conv2D(16,3,strides=2,activation='relu')(x)
# x=BatchNormalization(center=True,scale=False)(x)
# x=Activation('relu')(x)
x=Conv2D(32,5,activation='relu')(x)
# x=BatchNormalization(center=True,scale=False)(x)
# x=Activation('relu')(x)
x=Conv2D(32,5,activation='relu')(x)
# x=BatchNormalization(center=True,scale=False)(x)
# x=Activation('relu')(x)
x=Flatten()(x)
x=Dense(60,activation='relu')(x)
# x=BatchNormalization(center=True,scale=False)(x)
# x=Activation('relu')(x)
# x=Dropout(0.3)(x)
x=Dense(10,activation='softmax')(x)
x=Lambda(lambda t:t/tempetature)(x)# softmax後除以溫度,使得更平滑
student=Model(inputs,x)
student.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),loss=categorical_crossentropy,metrics=['accuracy'])
student.fit(X_train,model.predict(X_train)/tempetature,batch_size=128,epochs=30,verbose=2)
student.evaluate(X_test,y_test/tempetature)# 98.80%

複製代碼

 4,在3的基礎上,loss部分加上學生網絡與分類標籤的損失,測試精度98.79%。基本沒變化,此時這個損失倒不太重要了。

複製代碼

# 凍結老師網絡
model.trainable=False
# 定義溫度
temperature=3
# 自定義loss,加上學生網絡與真實標籤的損失,這個損失計算應使學生網絡溫度爲1,即這個損失不用除以溫度
class Calculate_loss(Layer):
    def __init__(self,T,label_loss_weight,**kwargs):
        '''
        T: temperature for soft-target
        label_loss_weight: weight for loss between student-net and labels, could be small because the other loss is more important
        '''
        self.T=T
        self.label_loss_weight=label_loss_weight
        super(Calculate_loss,self).__init__(**kwargs)
    def call(self,inputs):
        student_output=inputs[0]
        teacher_output=inputs[1]
        labels=inputs[2]
        loss_1=categorical_crossentropy(teacher_output/self.T,student_output/self.T)
        loss_2=self.label_loss_weight*categorical_crossentropy(labels,student_output)
        self.add_loss(loss_1+loss_2,inputs=inputs)
        return labels
# 將標籤轉化爲tensor輸入
y_inputs=Input((10,))# 類似placeholder作用
y=Lambda(lambda t:t)(y_inputs)
# 學生網絡
inputs=Input((28,28,1))
x=Conv2D(16,3,activation='relu')(inputs)
x=Conv2D(16,3,strides=2,activation='relu')(x)
x=Conv2D(32,5,activation='relu')(x)
x=Conv2D(32,5,activation='relu')(x)
x=Flatten()(x)
x=Dense(60,activation='relu')(x)
x=Dense(10,activation='softmax')(x)
x=Calculate_loss(T=temperature,label_loss_weight=0.1)([x,model(inputs),y])
student=Model([inputs,y_inputs],x)
student.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),loss=None)
student.summary()
student.fit(x=[X_train,y_train],y=None,batch_size=128,epochs=30,verbose=2)

複製代碼

提取出預測模型,標籤one-hot化了,重新加載一下

複製代碼

softmax_layer=student.layers[-4]

predict_model=Model(inputs,softmax_layer.output)

res=predict_model.predict(X_test)

import numpy as np
result=[np.argmax(a) for a in res]

(x_train,y_train),(x_test,y_test)=mnist.load_data()

from sklearn.metrics import accuracy_score
accuracy_score(y_test,result)# 98.79%

複製代碼

 5,作爲對比,相同網絡不使用蒸餾,測試精度98.4%

複製代碼

# 對應上面,不使用蒸餾,精度爲98.4%
inputs=Input((28,28,1))
x=Conv2D(16,3,activation='relu')(inputs)
x=Conv2D(16,3,strides=2,activation='relu')(x)
x=Conv2D(32,5,activation='relu')(x)
x=Conv2D(32,5,activation='relu')(x)
x=Flatten()(x)
x=Dense(60,activation='relu')(x)
x=Dense(10,activation='softmax')(x)
student=Model(inputs,x)
student.compile(optimizer=optimizers.SGD(momentum=0.9,nesterov=True),loss=categorical_crossentropy,metrics=['accuracy'])
student.summary()
# student.fit(X_train,y_train,validation_split=0.2,batch_size=128,epochs=30,verbose=2)
student.fit(X_train,y_train,batch_size=128,epochs=10,verbose=2)
student.evaluate(X_test,y_test)

複製代碼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章