Paper: Correlation Congruence for Knowledge Distillation
1, Motivation:
Usually, conventional KD does not take the intra-class / inter-class distribution of the teacher model's feature space into account, so the student model will also lack the intra-class / inter-class distribution characteristics we expect.
Usually, the embedding space of the teacher possesses the characteristic that intra-class instances cohere together while inter-class instances separate from each other, but the embedding space of a student trained only with instance congruence would lack this desired characteristic.
2, Contribution:
- Proposes Correlation Congruence Knowledge Distillation (CCKD), which enforces not only instance congruence but also correlation congruence. (Instance congruence is achieved through PK sampling or clustering of the mini-batch; correlation congruence is enforced by a loss that constrains the correlation between samples i and j, as sketched after this list.)
- Turns the correlation computation within a mini-batch into a single large matrix operation over the whole mini-batch, reducing the computational cost.
- Adopts different mini-batch sampler strategies.
- Experiments on CIFAR-100, ImageNet-1K, person re-identification and face recognition.
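A minimal sketch of the correlation-congruence idea above, assuming PyTorch; the function and tensor names are illustrative, and the dot-product (cosine) correlation used here is only one of the metrics the paper considers. The whole (B, B) correlation matrix comes from a single matrix multiplication, matching the second bullet.

```python
import torch
import torch.nn.functional as F

def correlation_congruence_loss(feat_t, feat_s):
    """Sketch of a correlation-congruence loss.

    feat_t, feat_s: (B, D) embeddings of the same mini-batch from the
    teacher and the student. Names and the dot-product correlation are
    illustrative; the paper also considers other correlation metrics.
    """
    # L2-normalize so the dot product acts as a cosine-similarity correlation.
    t = F.normalize(feat_t, dim=1)
    s = F.normalize(feat_s, dim=1)

    # One (B, B) matrix multiply yields every pairwise correlation at once.
    corr_t = t @ t.t()
    corr_s = s @ s.t()

    # Penalize the discrepancy between teacher and student correlation matrices.
    return (corr_t - corr_s).pow(2).mean()
```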
3, Paper framework:
3.3. Correlation Congruence
Correlation congruence knowledge distillation
- Extract features
- Map them into the embedding feature space
- The correlation mapping can be any correlation metric, and the paper introduces three metrics for capturing the correlation between instances.
- Compute the correlation matrix
Correlation congruence: loss formula
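A reconstruction of the correlation-congruence loss, with $\psi$ the chosen correlation metric and $n$ the mini-batch size (notation paraphrased from the paper rather than quoted verbatim):

$$
L_{CC} = \frac{1}{n^{2}} \sum_{i=1}^{n} \sum_{j=1}^{n} \bigl( \psi(t_i, t_j) - \psi(s_i, s_j) \bigr)^{2}
$$

where $t_i$ and $s_i$ are the teacher and student embeddings of the i-th instance in the mini-batch.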
Gaussian RBF is more flexible and powerful in capturing the complex non-linear relationship between instances. (The paper finally adopts the Gaussian kernel to compute correlations, but its computational cost is quite high.)
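A sketch of the exact Gaussian-RBF correlation matrix for one mini-batch, assuming PyTorch; the bandwidth value `gamma` is a hypothetical choice.

```python
import torch

def gaussian_rbf_correlation(emb, gamma=0.4):
    """Pairwise Gaussian-RBF correlations k(x_i, x_j) = exp(-gamma * ||x_i - x_j||^2).

    emb: (B, D) embeddings; gamma is a hypothetical bandwidth value.
    Computing the exact kernel for every pair is what makes this metric expensive.
    """
    # (B, B) matrix of squared Euclidean distances between all pairs.
    dist_sq = torch.cdist(emb, emb, p=2).pow(2)
    return torch.exp(-gamma * dist_sq)
```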
Loss function:
(Compared with conventional KD, there is one additional loss term enforcing correlation congruence.)
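Written out, the overall objective is the standard KD objective plus the correlation-congruence term; a reconstruction consistent with the description above (the coefficient names $\alpha$ and $\beta$ are illustrative):

$$
L = \alpha\, L_{CE} + (1-\alpha)\, L_{KD} + \beta\, L_{CC}
$$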
4, Experimental results:
It can be seen that with the correlation-congruence constraint added, intra-class instances are pulled closer together and the inter-class separation is larger.
5, Setting:
On CIFAR-100, ImageNet-1K and MSMT17, the original knowledge distillation (KD) [15] and cross-entropy (CE) are chosen as the baselines. For face recognition, ArcFace loss [5] and L2-mimic loss [21, 23] are adopted. We compare CCKD with several state-of-the-art distillation-related methods, including attention transfer (AT) [37], deep mutual learning (DML) [39] and conditional adversarial network (Adv) [35]. For attention transfer, we add it to the last two blocks as suggested in [37]. For adversarial training, the discriminator consists of FC(128 × 64) + BN + ReLU + FC(64 × 2) + Sigmoid activation layers, and we adopt the binary cross-entropy loss to train it.
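For reference, a minimal sketch of the adversarial-baseline discriminator described above, assuming PyTorch; the layer sizes follow the text, while the variable names are illustrative.

```python
import torch.nn as nn

# Discriminator for the adversarial (Adv) baseline, as described in the setting:
# FC(128 -> 64) + BN + ReLU + FC(64 -> 2) + Sigmoid, trained with a BCE loss.
discriminator = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 2),
    nn.Sigmoid(),
)
adv_criterion = nn.BCELoss()
```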
ResNet-50 is used as the teacher network and ResNet-18 as the student network. The dimension of the feature representation is set to 256. We set the weight decay to 5e-4 and the batch size to 40, and use stochastic gradient descent with momentum. The learning rate is set to 0.0003, then divided by 10 at epochs 45 and 60, for 90 epochs in total.
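A sketch of this optimizer and schedule in PyTorch; the momentum value 0.9 is an assumption (the text only says "with momentum"), and the student model here is a stand-in.

```python
import torch
from torchvision.models import resnet18

student = resnet18()  # stand-in for the ResNet-18 student described above

optimizer = torch.optim.SGD(
    student.parameters(),
    lr=3e-4,           # initial learning rate from the setting above
    momentum=0.9,      # assumed value; the note only states "with momentum"
    weight_decay=5e-4,
)
# Divide the learning rate by 10 at epochs 45 and 60; train for 90 epochs in total.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[45, 60], gamma=0.1)
```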