其實應該最先寫這篇文章的總結的,之前看了忘了記錄
Motivation
one hot label會將所有不正確的類別概率都設置爲0,而一個好的模型預測出來的結果,這些不正確的類別概率是有不同的,他們之間概率的相對大小其實蘊含了更多的信息,代表着模型是如何泛化判別的。
比如一輛轎車,一個模型更有可能把它預測成卡車而不是貓,這其實給出了比one hot label更多的信息即轎車和卡車更像,而和貓不像。
如果一個大的模型做到了很好的泛化性能,那我們可以用一個小的模型去模擬他的泛化結果去達到較好的效果
Method
Loss = CE(softmax(predict), one hot label) + alpha * T * T * CE(softmax(predict/T), soft target)
T作爲一個超參,當T很大時,qi會更加soft,比如T趨於無窮大,則qi=(1/n, 1/n…)
當T較小時(比如T=1),需要去匹配更多的不正確類別的概率。如果student和teacher性能相差較大,可設置T爲中等大小
VS Matching logits(Caruana提出的)
Matching logits(https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf) is a special case of distillation
C = CE(softmax(predict/T), soft target),根據CE的求導公式得
如果temperature T比logits的量級(magnitude)要大得多,那麼zi/T->0,zi<0時從左邊趨近0,>0時從右邊趨近0,所有e^(zi/T) =1+zi/T
假設對於每一個transfer case,都有logits的均值爲0,所以上式可以簡化爲
所以,如果temperature T很高,如果對於每一個transfer case,都有logits的均值爲0,那麼distillation就等價於最小化1/2(zi−vi)^2,也就是Caruana提出的使得複雜模型的logits和小模型的logits的平方差最小
https://daiwk.github.io/posts/dl-knowledge-distill.html
Soft Targets as Regularizers
用soft target進行訓練避免了過擬合