孿生網絡入門(上) Siamese Net及其損失函數

最近在多個關鍵詞(小數據集,無監督半監督,圖像分割,SOTA模型)的範疇內,都看到了這樣的一個概念,孿生網絡,所以今天有空大概翻看了一下相關的經典論文和博文,之後做了一個簡單的案例來強化理解。如果需要交流的話歡迎聯繫我,WX:cyx645016617

所以這個孿生網絡入門,我想着分成上下兩篇,上篇也就是這一篇講解模型理論、基礎知識和孿生網絡獨特的損失函數;下篇講解一下如何用代碼來復線一個簡單的孿生網絡。

1 名字的由來

孿生網絡的別名就會死Siamese Net,而Siam是古代泰國的稱呼,所以Siamese其實是“泰國人”的古代的稱呼。 爲什麼Siamese現在在英文中是“孿生”“連體”的意思呢?這源自一個典故:

十九世紀泰國出生了一對連體嬰兒,當時的醫學技術無法使兩人分離出來,於是兩人頑強地生活了一生,1829年被英國商人發現,進入馬戲團,在全世界各地表演,1839年他們訪問美國北卡羅萊那州後來成爲“玲玲馬戲團” 的臺柱,最後成爲美國公民。1843年4月13日跟英國一對姐妹結婚,恩生了10個小孩,昌生了12個,姐妹吵架時,兄弟就要輪流到每個老婆家住三天。1874年恩因肺病去世,另一位不久也去世,兩人均於63歲離開人間。兩人的肝至今仍保存在費城的馬特博物館內。從此之後“暹羅雙胞胎”(Siamese twins)就成了連體人的代名詞,也因爲這對雙胞胎讓全世界都重視到這項特殊疾病。

2 模型結構

這個圖有這幾個點來理解:

  • 其中的Network1和Network2按照專業的話來說就是共享權制,說白了這兩個網絡其實就是一個網絡,在代碼中就構建一個網絡就行了;
  • 一般的任務,每一個樣本經過模型得到一個模型的pred,然後這個pred和ground truth進行損失函數的計算,然後得到梯度;這個孿生網絡則改變了這種結構,假設是圖片分類的任務,把圖片A輸入到模型中得到了一個輸出pred1,然後我再把圖片B輸入到模型中,得到了另外一個輸出pred2,然後我這個損失函數是從pred1和pred2之間計算出來的。 就是一般情況下,模型運行一次,給出一個loss,但是在siamese net中,模型要運行兩次才能得到一個loss。
  • 我個人感覺,一般的任務像是衡量一種絕對的距離,樣本到標籤的一個距離;但是孿生網絡衡量的是樣本到樣本之間的一個距離。

2.1 孿生網絡的用途

Siamese net衡量的是兩個輸入的關係,也就是兩個樣本相似還是不相似。

有這樣的一個任務,在NIPS上,在1993年發表了文章《Signature Verification using a ‘Siamese’ Time Delay Neural Network》用於美國支票上的簽名驗證,檢驗支票上的簽名和銀行預留的簽名是否一致。當時論文中就已經用卷積網絡來做驗證了...當時我還沒出生。

之後,2010年Hinton在ICML上發表了《Rectified Linear Units Improve Restricted Boltzmann Machines》,用來做人臉驗證,效果很好。輸入就是兩個人臉,輸出就是same or different

可想而知,孿生網絡可以做分類任務。在我看來,孿生網絡不是一種網絡結構,不是resnet那種的網絡結構,而是一種網絡的框架,我可以把resnet當成孿生網絡的主幹網絡這樣的

既然孿生網絡的backbone(我們暫且這樣叫,應該可以理解的把)可以是CNN,那麼也自然可以是LSTM,這樣可以實現詞彙的語義的相似度分析

之前Kaggle上有一個question pair的比賽,衡量兩個問題是否提問的是同一個問題這樣的比賽,TOP1的方案就是這個孿生網絡的結構Siamese net。

後來好像還有基於Siamese網絡的視覺跟蹤算法,這個我還沒有了解,以後有機會的話我看一看這個論文。《Fully-convolutional siamese networks for object tracking》。先挖一個坑。

2.2 僞孿生網絡

問題來了,孿生網絡中看似兩個網絡,實則共享權制爲一個網絡,假設我們真的給他弄兩個網絡,那樣不就可以一個是LSTM,一個CNN實現不同模態的相似度比較了?

沒錯,這個叫做pseudo-siamese network 僞孿生網絡。一個輸入是文字,一個輸入是圖片,判斷文字描述是否是圖片內容;一個是短標題,一個是長文章,判斷文章內容是否是標題。(高中語文作文常年跑題選手的救星,以後給老師說這個算法說我的文章沒有跑題,您要不再看看?老師會打死我嗎)

不過本文和下一篇的代碼都是以siamese network爲核心,backbone也以CNN卷積網絡和圖像展開。

2.3 三胞胎

既然有了二胞胎的網絡,當然也有三胞胎,叫做Triplet network《Deep metric learning using Triplet network》。據說效果已經好過Siamese network了,不知道有沒有四胞胎和五胞胎。

3 損失函數

分類任務常規使用softmax加上交叉熵,但是有人提出了,這種方法訓練的模型,在“類間”區分性上表現的並不好,使用對抗樣本攻擊就立刻不行了。後續有空講解一下對抗樣本攻擊,再挖個坑。 簡單的說就是,假設是人臉識別,那麼每個人就是一個類別,那麼你讓一個模型做一個幾千分類的任務,每一個類別的數據又很少的情況下,想想也會感覺到這個訓練的難度。

針對這樣的問題,孿生網絡有兩個損失函數比較近經典:

  • Contrastive Loss
  • Triplte Loss

3.1 Contrastive Loss

  • 提出論文:《Dimensionality Reduction by Learning an Invariant Mapping》
    現在我們已知:
  • 圖片1 經過模型 得到pred1
  • 圖片2 經過模型 得到pred2
  • pred1和pred2計算得到loss

論文中給出了這樣的一個計算公式:

首先呢,這個經過模型得到的pred1和pred2都是向量,過程相當於圖片經過CNN提取特徵,然後得到了一個隱含向量,是一個Encoder的感覺。

然後計算這兩個向量的歐氏距離,這個距離(如果模型訓練的正確的話),就可以反應兩個輸入圖像的相關性。我們每次輸入兩個圖片,我們需要事先確定這兩個圖像是一類的,還是不同的,這個類似一個標籤,也就是上圖公式中的Y。如果是一類的,那麼Y爲0,如果不是,Y=1

類似於二值交叉熵損失函數,我們需要注意的是:

  • Y=0的時候,損失爲:\((1-Y)L_S(D_W^i)\)
  • Y=1的時候,損失爲:\(YL_D(D_W^i)\).
  • 其中論文中\(L_D,L_S\)是常數,論文中默認取0.5
  • i是一個次方的含義,論文中和常用的contrastive loss中,都是默認i=2,也就是歐氏距離的平方。
  • 對於類別是1(different類別的),我們自然是希望pred1和pred2的歐氏距離越大越好。那麼這個大到什麼程度是個頭呢?損失函數是往小的方向移動,那麼需要做什麼呢?增加一個margin,當作最大的距離。如果pred1和pred2的距離大於margin,那麼就認爲這兩個樣本距離足夠大,就當其的損失爲0。所以寫的方法就是:\(max(margin-distance,0)\).
  • 上圖中的W我理解爲神經網絡的weight,然後\(\vec X_1\),表示要輸入的原圖片。

所以損失函數就變成這個樣子:

總結一下,這裏面需要注意的應該就是對於different的兩個圖片,需要設置一個margin,然後小於margin的計算損失,大於margin的損失爲0.

3.2 Contrastive Loss pytorch

# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +     # calmp夾斷用法
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))     
 

        return loss_contrastive

其中唯一需要談一下的可能就是torch.nn.functional.pariwise_distance,
這個就是計算對應元素的歐氏距離,舉個例子:

import torch
import torch.nn.functional as F
a = torch.Tensor([[1,2],[3,4]])
b = torch.Tensor([[10,20],[30,40]])
F.pairwise_distance(a,b)

輸出爲:

然後看一下這個數字是不是歐氏距離:

沒問題的啊

3.3 Triplte Loss

  • 提出論文:《FaceNet: A Unified Embedding for Face Recognition and Clustering》

這個論文提出了FactNet,然後使用了Triplte Loss。Triplet Loss即三元組損失,我們詳細來介紹一下。

  • Triplet Loss定義:最小化錨點和具有相同身份的正樣本之間的距離,最小化錨點和具有不同身份的負樣本之間的距離。這個其實應該是三胞胎網絡的損失函數,同時輸入三個樣本,一個圖片,然後一個same類別的圖片和一個different圖片。
  • Triplet Loss的目標:Triplet Loss的目標是使得相同標籤的特徵在空間位置上儘量靠近,同時不同標籤的特徵在空間位置上儘量遠離,同時爲了不讓樣本的特徵聚合到一個非常小的空間中要求對於同一類的兩個正例和一個負例,負例應該比正例的距離至少遠margin。如下圖所示:

這個的話我們要如何構建損失函數呢?已知我們想要的:

  • 讓anchor和positive得到的向量的歐氏距離越小越好;
  • 讓anchor和negative得到的向量的歐氏距離越大越好;

所以期望下面這個公式成立:

  • 簡單的說就是anchor和positive的距離要比anchor和negative的距離小,而且這個差距要至少要大於\(\alpha\)個人的思考是,這裏的T,是三元組的集合。對於一個數據集,往往可以構建出非常多的三元組,因此我個人感覺這種任務一般用在類別多,數據量較少的任務中,不然三元組數量爆炸了

3.4 Triplte Loss keras

這裏有一個keras的triplte loss的代碼

def triplet_loss(y_true, y_pred):
        """
        Triplet Loss的損失函數
        """

        anc, pos, neg = y_pred[:, 0:128], y_pred[:, 128:256], y_pred[:, 256:]

        # 歐式距離
        pos_dist = K.sum(K.square(anc - pos), axis=-1, keepdims=True)
        neg_dist = K.sum(K.square(anc - neg), axis=-1, keepdims=True)
        basic_loss = pos_dist - neg_dist + TripletModel.MARGIN

        loss = K.maximum(basic_loss, 0.0)

        print "[INFO] model - triplet_loss shape: %s" % str(loss.shape)
        return loss

參考文獻:

[1] Momentum Contrast for Unsupervised Visual Representation Learning, 2019, Kaiming He Haoqi Fan Yuxin Wu Saining Xie Ross Girshick

[2] Dimensionality Reduction by Learning an Invariant Mapping, 2006, Raia Hadsell, Sumit Chopra, Yann LeCun

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章