【乾貨】虛擬對抗訓練簡介

虛擬對抗訓練是一種有效的正則化技術,在監督學習,半監督學習和無監督聚類方面取得了良好的效果。

虛擬對抗訓練已用於:

  1. 提高監督學習績效
  2. 半監督學習
  3. 深度無監督聚類

有幾種正則化技術可以防止過度擬合,並有助於模型更好地概括出看不見的例子。正則化有助於模型參數更少地依賴於訓練數據。兩種最常用的正則化技術是Dropout和L1 / L2正則化。

在L1 / L2正則化中,我們添加一個損失項,試圖減小權重矩陣的L1範數或L2範數。較小的權重值將導致更簡單的模型,這些模型不易過度擬合。

在Dropout中,我們在訓練時隨機忽略一些神經元。這使得網絡對噪聲和輸入變化更加魯棒。

在所提到的兩種技術中,都沒有考慮輸入數據分佈。

局部分佈平滑

局部分佈平滑度(LDS)可以定義爲模型輸出分佈相對於輸入的平滑度。我們不希望模型對輸入中的小擾動敏感。我們可以說,對於模型輸入的微小變化,模型輸出不應該有大的變化。

在LDS正則化中,模型分佈的平滑性得到獎勵。它也是網絡上參數的不變量,僅取決於模型輸出。具有平滑的模型分佈應該有助於模型更好地推廣,因爲模型將爲看不見的數據點提供類似的輸出,這些數據點接近訓練集中的數據點。一些研究表明,使模型對小的隨機擾動具有魯棒性對於正則化是有效的。

LDS正則化的一種簡單方法是通過在實際數據點上應用小的隨機擾動來生成人工數據點。之後,鼓勵模型爲真實和擾動的數據點提供類似的輸出。領域知識也可用於產生更好的擾動。例如,如果輸入是圖像,則可以使用各種圖像增強技術,例如翻轉,旋轉,變換顏色。

輸入數據轉換的示例

虛擬對抗訓練

虛擬對抗訓練是一種有效的局部分配平滑性技術。採用成對的數據點,這些數據點在輸入空間中非常接近,但在模型輸出空間中非常接近。然後訓練模型以使它們的輸出彼此接近。爲此,採用給定的輸入並且發現擾動,模型給出非常不同的輸出。然後,模型因擾動而對靈敏度進行處罰。

虛擬對抗訓練的關鍵步驟是:

  1. 從輸入數據點x開始
  2. 通過添加小的擾動r來變換x,因此變換的數據點將是T(x)= x + r
  3. 擾動r應該在對側方向 - 擾動輸入T(x)的模型輸出應該與非擾動輸入的輸出不同。特別是,兩個輸出分佈之間的KL差異應該是最大的,同時確保r的L2範數很小。從所有的擾動r,讓r v-adv成爲對抗方向的擾動。

  1. 在找到對抗擾動和變換輸入之後,更新模型的權重,使得KL散度最小化。這將使模型對不同的擾動具有魯棒性。通過梯度下降最小化以下損失:

在虛擬對抗訓練期間,模型對於不同的輸入擾動變得更加魯棒。隨着模型變得更加穩健,產生擾動變得更加困難並且觀察到損失的下降。

可以將此方法視爲與生成性對抗網絡類似。但是有幾個不同之處:

  1. 不是讓發生器欺騙鑑別器,而是在輸入中添加一個小擾動,以欺騙模型,認爲它們是兩個截然不同的輸入。
  2. 不是區分僞造和真實,而是使用模型輸出之間的KL分歧。在訓練模型時(類似於訓練鑑別器),我們最小化KL分歧。

虛擬對抗訓練可以被認爲是一種有效的數據增強技術,我們不需要先前的領域知識。這可以應用於所有類型的輸入分佈,因此對於真正的“無監督學習”是有用的。

虛擬對抗訓練與對抗訓練有何不同?

在對抗訓練中,標籤也用於產生對抗性擾動。產生擾動使得分類器的預測標籤y'變得與實際標籤y不同。

在虛擬對抗訓練中,不使用標籤信息,僅使用模型輸出生成擾動。產生擾動使得擾動輸入的輸出不同於原始輸入的模型輸出(與地面實況標籤相反)。

實施虛擬對抗訓練

現在我們將使用Tensorflow和Keras實現基本的虛擬對抗訓練。完整的代碼可以在這裏找到

首先,在Keras中定義神經網絡

network = Sequential()

network.add( Dense(100 ,activation='relu' , input_shape=(2,)))

network.add( Dense( 2 ))

定義model_input中,logits p_logit通過將輸入到網絡和概率得分p通過在logits施加SOFTMAX活化。

model_input = Input((2,))

p_logit = network( model_input )

p = Activation('softmax')( p_logit )

爲了產生對抗性擾動,從隨機擾動開始r並使其成爲單位範數。

r = tf.random_normal(shape=tf.shape( model_input ))

r = make_unit_norm( r )

擾動輸入的輸出logits將是 p_logit_r

p_logit_r = network( model_input + 10*r )

現在計算來自輸入和擾動輸入的log的KL偏差。

kl = tf.reduce_mean(compute_kld( p_logit , p_logit_r ))

爲了獲得對抗性擾動,我們需要r使KL-發散最大化。因此,採取kl相對於的梯度 r。對抗性擾動將是梯度。我們使用該stop_gradient函數是因爲我們希望r_vadv在反向傳播時保持固定。

grad_kl = tf.gradients( kl , [r ])[0]

最後,規範化範數對抗性擾動。我們將範數設定r_vadv爲一個較小的值,即我們想要沿着對抗方向前進的距離。

r_vadv = make_unit_norm( r_vadv )/3.0

現在我們有對抗擾動r_vadv,模型給出了非常大的輸出差異。我們需要在模型中添加一個損失,這會損害模型,使其具有與原始輸入和擾動輸入的輸出相比具有較大KL偏差的模型。

p_logit_r_adv = network( model_input + r_vadv )

vat_loss = tf.reduce_mean(compute_kld( tf.stop_gradient(p_logit), p_logit_r_adv ))

最後,構建模型並附加vat_loss。

model_vat = Model(model_input , p )

model_vat.add_loss( vat_loss )

model_vat.compile( 'sgd' , 'categorical_crossentropy' , metrics=['accuracy'])

現在讓我們使用一些合成數據來訓練和測試模型。該數據集是二維的,有兩個類。1類數據點位於外環中,2類數據點位於內環中。我們每班僅使用8個數據點進行培訓,並使用1000個數據點進行測試。

合成數據集在2D平面上的圖

讓我們通過調用fit函數來訓練模型。

model.fit( X_train , Y_train_cat )

可視化模型輸出

現在,讓我們可視化模型的輸出空間以及訓練和測試數據。

模擬決策邊界與虛擬對抗訓練

對於這個示例數據集,非常明顯的是,具有虛擬對抗訓練的模型已經更好地推廣並且其決策邊界也在於測試數據的邊界。

沒有虛擬對抗訓練的模型決策邊界

對於沒有虛擬對抗訓練的模型,我們看到訓練數據點有些過度擬合。在這種情況下,決策邊界不好並且與其他類重疊。

虛擬對抗訓練的應用

虛擬對抗訓練已經在半監督學習和無監督學習中的各種應用中顯示出令人難以置信的結果。

半監督學習的增值稅:虛擬對抗性訓練在半監督學習中表現出良好的效果。在這裏,我們有大量未標記的數據點和一些標記的數據點。應用vat_loss未標記的集合和標記集合上的監督損失可以提高測試精度。作者表明該方法優於其他幾種半監督學習方法。您可以在此處閱讀更多內容。

虛擬對抗梯形網絡:梯形網絡已經顯示出半監督分類的有希望的結果。在那裏,在每個輸入層,添加隨機噪聲並且訓練解碼器以對每層的輸出進行去同化。在虛擬對抗梯形網絡中,不使用隨機噪聲,而是使用對抗性噪聲。您可以在此處閱讀更多內容。

使用自增強訓練的無監督聚類:這裏的目標是在不使用任何標記樣本的情況下將數據點聚類在固定數量的聚類中。 規範化信息最大化是一種用於無監督聚類的技術。這裏輸入和模型輸出之間的相互信息被最大化。IMSAT通過添加虛擬對抗訓練擴展了該方法。隨着互信息的丟失,作者應用了vat_loss。在添加虛擬對抗訓練後,它們顯示出很大的改進。您可以在論文和我之前的博客文章中閱讀更多內容。

結論

在這篇文章中,我們討論了一種稱爲虛擬對抗訓練的有效正則化技術。我們還使用Tensorflow和Keras進行實施。我們觀察到,當訓練樣本很少時,含增值稅的模型表現更好。我們還討論了使用虛擬對抗訓練的各種其他作品。如果您有任何疑問或想要建議任何更改,請隨時與我聯繫或在下面寫評論。

相關源碼關注微信公衆號:“圖像算法”或者微信搜索賬號imalg_cn 獲取

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章