Article Analysis(AA): A Simple Framework for Contrastive Learning of Visual Representations

本文爲讀文章筆記,受所學所知限制,如有出錯,恭請指正。


A Simple Framework for Contrastive Learning of Visual Representations
作者: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton

本文提出一種簡潔有效的設計的無監督設計,並且以7%的margin刷新了SOTA。

摘要直譯:這篇文章提出了SimCLR, 一種簡單的、用於視覺表徵對比學習的框架。作者們簡化了最近剛提出的對比自監督學習算法,並且不需要特別的架構或者J記憶庫。爲了探究是什麼使得對比預測任務能夠學習到游泳的表徵,作者們系統地研究了該框架的大部分組件。作者們展示出(1)數據增強的組成在定義高效預測任務中具有關鍵的作用,(2)在表徵和對比損失之間引入了一種可學習的非線性變換,該變換能夠實質性地提高學習到的表徵的質量,(3)相對於監督學習,對比學習能夠從更大的batch size和更多的訓練中獲益。通過組合以上要點,在ImageNet上,作者們的方法能夠大大的超過之前用於自監督和半監督的方法。一個用SimCLR學習的自監督表徵的線性分類器能夠達到76.5%的top-1精度,這是7%的相對提升,超過之前的SOTA, 且與監督模型ResNet-50的性能無異。僅僅1%的標籤量用於微調,就能達到85.8%的top-5精度,以少100倍的標籤量超過AlexNet。


核心分析

對比學習框架,如下圖
在這裏插入圖片描述
該框架有四個主要模塊:
1, 隨機數據增強模塊,它能夠隨機地變換任何給定的數據樣本,即生成同一樣本的兩個相關表徵,xi^\hat{x_i}xj^\hat{x_j},也就是一個正樣本對,如上圖。在文章中,順序應用了3個簡單的增強方式,隨機剪裁之後,Resize到同一尺寸,接着是隨機顏色擾動,隨機高斯模糊。特別的是,隨機剪裁和顏色擾動的組合對獲得好性能至關重要。
2,用於從增強後的數據樣本中提取表徵向量的神經網絡基礎編碼器(base encoder)f()f()。該框架能夠無限制的適用不同的網絡框架。文章中,作者們採用簡單通用的ResNet來計算hih_i, 即hi=f(xi^)=ResNet(xi^)h_i=f(\hat{x_i})=ResNet(\hat{x_i}), 其中hiRdh_i \in R^d是均值池化後的輸出。
3, 神經網絡映射頭(projection head)g()g(),用來將表徵映射到對比損失應用的空間。文章中用一個隱藏層的MLP來計算ziz_izi=g(hi)=W(2)σ(W(1)hi)z_i=g(h_i)=W^{(2)}\sigma(W^{(1)}h_i),其中σ\sigma是一個ReLU。作者認爲在ziz_i上定義對比損失比在hih_i上更好。
4, 對比損失函數,用於對比預測任務。給定一個包含正樣本對xi^\hat{x_i}xj^\hat{x_j}的數據集xk^{\hat{x_k}},對比預測任務目標是,給定xi^\hat{x_i}後,在{xk^}ki\{\hat{x_k}\}_{k \neq i}中識別xj^\hat{x_j}

給定一個minibatch NN的樣本,在該批增強後的樣本上定義對比預測任務,則有2N2N個數據點。注意並沒有採樣負樣例。給定一對正樣例,同批次中其他2(N1)2(N-1)的增強樣例作爲負樣例。

兩個向量vvuu之間的餘弦相似度,即 sim(u,v)=uTv/uvsim(u, v)=u^{T}v / ||u|| \cdot ||v||,那麼對一對正樣本(i,j)(i, j)有損失函數
li,j=logexp(sim(zi,zj)/τ)k=12N1[ki]exp(sim(zi,zk)/τ)l_{i, j}= - log \frac{exp(sim(z_i, z_j)/\tau)}{\sum^{2N}_{k=1} 1_{[k \neq i]} exp(sim(z_i, z_k)/\tau)}
其中1[ki]{0,1}1_{[k \neq i]} \in \{0, 1\}是指示函數, τ\tau是一個溫度參數。最終的損失需要計算批次中所有的正樣例對,即(i,j)(i, j), (j,i)(j, i)。文章中, 作者們稱以上爲NT-Xent(the normalized temperature-scaled cross entropy loss)。
在這裏插入圖片描述
以上是文章核心內容的說明。消融實驗非常值得看,這裏不在列出;放個結果圖
在這裏插入圖片描述
其中,(1×,2×,4×)(1 \times, 2\times, 4\times)指的是ResNet-50中3個不同的隱藏層寬度,見文章第六部分第二行。


參考文獻:
[1] https://arxiv.org/pdf/2002.05709.pdf 文章源地址。
[2] http://xxx.itp.ac.cn/pdf/2002.05709.pdf 國內鏡像地址。

發佈了18 篇原創文章 · 獲贊 9 · 訪問量 2萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章