一言以蔽之,CVPR2018,分類softmax的替代品,或許不能明顯提點,但泛化性能更佳。
其實16個月前就嘗試過了,近期正好又需要用到,故而來整理下。
原作者只給出了python2的代碼,並且未給出特定格式的數據集,修改了訓練入口及兼容了python3。
詳細實現及代碼見:https://github.com/zmdsjtu/Convolutional-Prototype-Learning
二維mnist可視化後看着還是非常神清氣爽的:
- 從左到右分別爲softmax,CPL,GCPL
- 實際用其他數據集復現出來的效果也類似
Prototype Learning
中心思想——學習出m箇中心點(m可以和類別數量一致,也可以更多)
CNN網絡將原數據映射到一個n維空間,類似於LVQ聚類算法,構建m個“中心點”用以代表各類,在反向傳播的時候不斷迭代更新各個“中心點”的位置,迫使類內更聚合,類間距離更大,關鍵點就在於如何設計loss
網絡前向的時候,距離最近的中心點代表的類別即爲最終結果
Loss設計
MCE/MCL/DCE三選一
- Minimum classification error loss(MCE)
- Margin based classification loss(MCL)
- Distance based cross entropy loss(DCE)
加上pl組成最終版loss
- Generalized CPL with prototype loss(GCPL)
作者實現的mnist採用的DCE+0.01PL,都可以試一下,DCE收斂速度最快
詳細loss話不多說上代碼:
Minimum classification error loss(MCE)
確實能拉近類內距離和增大類間距離,這裏還有公式推導竟然,然而不重要,直接看loss代碼
- 如果分類正確,標籤距離-第二近的距離,爲負,sigmoid後作爲loss
- 如果分類錯誤,標籤距離-最近距離,爲正,sigmoid後作爲loss
def mce_loss(features, labels, centers, epsilon):
# 如果10類,爲一個N * 10的矩陣
dist = distance(features, centers)
values, indexes = tf.nn.top_k(-dist, k=2, sorted=True)
top2 = -values
d_1 = top2[:, 0]
d_2 = top2[:, 1]
row_idx = tf.range(tf.shape(labels)[0], dtype=tf.int32)
idx = tf.stack([row_idx, labels], axis=1)
# d_y 爲標籤的距離
d_y = tf.gather_nd(dist, idx, name='dy')
# indicator 正確的爲1,錯誤的爲0
indicator = tf.cast(tf.nn.in_top_k(-dist, labels, k=1), tf.float32)
# d_c,如果label正確爲第二近的距離;如果錯誤,爲最近的距離
d_c = indicator * d_2 + (1 - indicator) * d_1
# 如果標籤正確,標籤距離-第二近距離,爲負;
# 如果標籤錯誤,標籤距離-最近距離,爲正
measure = d_y - d_c
loss = tf.sigmoid(epsilon * measure, name='loss')
mean_loss = tf.reduce_mean(loss, name='mean_loss')
return mean_loss
Margin based classification loss(MCL)
加入了margin容錯度,sigmoid換成了relu,其他一致
def mcl_loss(features, labels, centers, margin):
dist = distance(features, centers)
values, indexes = tf.nn.top_k(-dist, k=2, sorted=True)
top2 = -values
d_1 = top2[:, 0]
d_2 = top2[:, 1]
row_idx = tf.range(tf.shape(labels)[0], dtype=tf.int32)
idx = tf.stack([row_idx, labels], axis=1)
d_y = tf.gather_nd(dist, idx, name='dy')
indicator = tf.cast(tf.nn.in_top_k(-dist, labels, k=1), tf.float32)
d_c = indicator * d_2 + (1 - indicator) * d_1
# 只考慮正確的,順便加上了“軟間隔”margin
loss = tf.nn.relu(d_y - d_c + margin, name='loss')
mean_loss = tf.reduce_mean(loss, name='mean_loss')
return mean_loss
Distance based cross entropy loss(DCE)
這個代碼看着非常舒服,距離的負數作爲logits算softmax loss
def dce_loss(features, labels, centers, t, weights=None):
dist = distance(features, centers)
logits = -dist / t
mean_loss = softmax_loss(logits, labels, weights)
return mean_loss
Generalized CPL with prototype loss(GCPL)
上述三個loss加一些正則,MCE/MCL/DCE + λPL
# prototype loss (PL)
def pl_loss(features, labels, centers):
batch_num = tf.cast(tf.shape(features)[0], tf.float32)
batch_centers = tf.gather(centers, labels)
dis = features - batch_centers
return tf.div(tf.nn.l2_loss(dis), batch_num)
然後這片文章就沒什麼了
順便放下結果:
- 泛化能力更佳,相同的數據量下,GCPL表現更佳(如果足夠其實差別不大,233)
- 新增類別可視化,中間那一坨爲新增類別(還是分得很開的,可以用來作爲unknown)
- out-of-domain判斷的能力,也即鑑別unknown的能力(mnint/CIFAR-10對比,其實差異蠻大的,233),至少可以吊打softmax
Reference:
1.論文:地址
2.源碼: 地址
3.論文解讀:地址