【已解決】使用keras對resnet, inception3進行fine-tune出現訓練集準確率很高但驗證集很低的問題(BN)

最近用keras跑基於resnet50,inception3的一些遷移學習的實驗,遇到一些問題。通過查看github和博客發現是由於BN層導致的,國外已經有人總結並提了一個PR(雖然並沒有被merge到Keras官方庫中),並寫了一篇博客,也看到知乎有人翻譯了一遍:Keras的BN你真的凍結對了嗎

當保存模型後再加載模型去預測時發現與直接預測結果不一致也可能是BN層的問題。

總結:

  • keras中通常用trainable這個參數來控制某一層的權重是否更新,例如trainable可以控制BN中的和是否變化。

  • TF爲後端時,BN有一個參數是training,控制歸一化時用的是當前Batch的均值和方差(訓練模式)還是移動均值和方差(測試模式),這個參數由Keras的K.learning_phase控制。若只設置trainable是不會影響BN的training參數。

  • 凍結時某一層時,我們希望這一層的狀態和預訓練模型中的狀態一致

  • 我們通常希望訓練和測試時網絡中的配置一致,但BN訓練和測試時的配置是不一樣的,而frozen這個行爲放大了這種不一致,導致精度下降。訓練時用了新數據集的均值和方差去做歸一化,測試時用了舊數據集的移動均值和方差去做歸一化

  • 爲了讓訓練和測試儘量一致,避免精度下降,有兩種方案,一種是在測試時也用舊數據集的移動均值和方差

  • 另一種方案是在訓練時也只用舊數據集的移動均值和方差,這是Keras作者fchollet在GitHub issue裏回覆的方案:在定義模型時,手動將training參數設爲False(可以通過顯式設置BN的training參數,或者通過設置learning_phase來隱式改變training參數),我覺得其實這種workaround還是挺好用的,而且也更符合frozen的意圖,即:

顯式設置:

   x = BatchNormalization()(y, training=False)

隱式設置:

    # Set up inference-mode base
    K.set_learning_phase(0)
    inputs = Input(...)
    x = layer1(...)(inputs)
    x = layer2(...)(x)
    ...
    x = layerN(...)(x)
    
    # Add training-mode layers
    K.set_learning_phase(1)
    x = layerNp1(...)(x)
    x = layerNp2(...)(x)

不可否認的是,默認的Frozen的BN的行爲在遷移學習中確實是有training這個坑存在的,個人認爲fchollet的修復方法更簡單一點,並且這種方式達到的效果和使用預訓練網絡提取特徵,單獨訓練分類層達到的效果是一致的,當你真的想要凍結BN層的時候,這種方式更符合凍結的這個動機;但在測試時使用新數據集的移動均值和方差一定程度上也是一種domain adaption。

譯文:
雖然Keras節省了我們很多編碼時間,但Keras中BN層的默認行爲非常怪異,坑了我(此處及後續的“我”均指原文作者)很多次。Keras的默認行爲隨着時間發生過許多的變化,但仍然有很多問題以至於現在Keras的GitHub上還掛着幾個相關的issue。在這篇文章中,我會構建一個案例來說明爲什麼Keras的BN層對遷移學習並不友好,並給出對Keras BN層的一個修復補丁,以及修復後的實驗效果。

1. Introduction

這一節我會簡要介紹遷移學習和BN層,以及learning_phase的工作原理,Keras BN層在各個版本中的變化。如果你已經瞭解過這些知識,可以直接跳到第二節(譯者注:1.3和1.4跟這個問題還是比較相關的,不全是背景)。

1.1 遷移學習在深度學習中非常重要

深度學習在過去廣受詬病,原因之一就是它需要太多的訓練數據了。解決這個限制的方法之一就是遷移學習。

假設你現在要訓練一個分類器來解決貓狗二分類問題,其實並不需要幾百萬張貓貓狗狗的圖片。你可以只對預訓練模型頂部的幾層卷積層進行微調。因爲預訓練模型是用圖像數據訓練的,底層卷積層可以識別線條,邊緣或者其他有用的模式作爲特徵使用,所以可以用預訓練模型的權重作爲一個很好的初始化值,或者只對模型的一部分用自己數據進行訓練。

在這裏插入圖片描述
Keras包含多種預訓練模型,並且很容易Fine-tune,更多細節可以查閱Keras官方文檔

1.2 Batch Normalization是個啥

BN在2014年由Loffe和Szegedy提出,通過將前一層的輸出進行標準化解決梯度消失問題,並減小了訓練達到收斂所需的迭代次數,從而減少訓練時間,使得訓練更深的網絡成爲可能。具體原理請看原論文,簡單來說,BN將每一層的輸入減去其在Batch中的均值,除以它的標準差,得到標準化的輸入,此外,BN也會爲每個單元學習兩個因子和來還原輸入。從下圖可以看到加了BN之後Loss下降更快,最後能達到的效果也更好。
在這裏插入圖片描述

1.3 Keras中的learning_phase是啥

網絡中有些層在訓練時和推導時的行爲是不同的。最重要的兩個例子就是BN和Dropout層對BN層,訓練時我們需要用mini batch的均值和方差來縮放輸入。在推導時,我們用訓練時統計到的累計均值和方差對推導的mini batch進行縮放

Keras用learning_phase機制來告訴模型當前的所處的模式。假如用戶沒有手工指定的話,使用fit()時,網絡默認將learning_phase設爲1,表示訓練模式。在預測時,比如調用predict()和evaluate()方法或者在fit()的驗證步驟中,網絡將learning_phase設爲0,表示測試模式。用戶可以靜態地,在model或tensor添加到一個graph中之前,將learning_phase設爲某個值(雖然官方不推薦手動設置),設置後,learning_phase就不可以修改了。

1.4 不同版本中的Keras是如何實現BN的

Keras中的BN訓練時統計當前Batch的均值和方差進行歸一化,並且使用移動平均法累計均值和方差,給測試集用於歸一化。

Keras中BN的行爲變過幾次,但最重要的變更發生在2.1.3這個版本。2.1.3之前,當BN被凍結時(trainable=False),它仍然會更新mini batch的移動均值和方差,並用於測試,造成用戶的困擾(一副沒有凍結住的樣子)。

這種設計是錯誤的。考慮Conv1-Bn-Conv2-Conv3這樣的結構,如果BN層被凍結住了,應該無事發生纔對。當Conv2處於凍結狀態時,如果我們部分更新了BN,那麼Conv2不能適應更新過的mini-batch的移動均值和方差,導致錯誤率上升

在2.1.3及之後,當BN層被設爲trainable=False時,Keras中不再更新mini batch的移動均值和方差,測試時使用的是預訓練模型中的移動均值和方差,從而達到凍結的效果, But is that enough? Not if you are using Transfer Learning.

2. 問題描述與解決方案

我會介紹問題的根源以及解決方案(一個Keras補丁)的技術實現。同時我也會提供一些樣例來說明打補丁前後模型的準確率變化。

2.1 問題描述

2.1.3版本後,當Keras中BN層凍結時,在訓練中會用mini batch的均值和方差統計值以執行歸一化。我認爲更好的方式應該是使用訓練中得到的移動均值和方差(譯者注:這樣不就退回2.1.3之前的做法了)。原因和2.1.3的修復原因相同,由於凍結的BN的後續層沒有得到正確的訓練,使用mini batch的均值和方差統計值會導致較差的結果。

假設你沒有足夠的數據訓練一個視覺模型,你準備用一個預訓練Keras模型來Fine-tune。但你沒法保證新數據集在每一層的均值和方差與舊數據集的統計值的相似性。注意哦,在當前的版本中,不管你的BN有沒有凍結,訓練時都會用mini-batch的均值和方差統計值進行批歸一化,而在測試時你也會用移動均值方差進行歸一化。因此,如果你凍結了底層並微調頂層,頂層均值和方差會偏向新數據集,而推導時,底層會使用舊數據集的統計值進行歸一化,導致頂層接收到不同程度的歸一化的數據。
在這裏插入圖片描述
如上圖所示,假設我們從Conv K+1層開始微調模型,凍結左邊1到k層。訓練中,1到K層中的BN層會用訓練集的mini batch統計值來做歸一化然而,由於每個BN的均值和方差與舊數據集不一定接近,在Relu處的丟棄的數據量與舊數據集會有很大區別,導致後續K+1層接收到的輸入和舊數據集的輸入範圍差別很大,後續K+1層的初始權重不能恰當處理這種輸入,導致精度下降。儘管網絡在訓練中可以通過對K+1層的權重調節來適應這種變化,但在測試模式下,Keras會用預訓練數據集的均值和方差,改變K+1層的輸入分佈,導致較差的結果。

2.2 如何檢查你是否受到了這個問題的影響

  • 分別將learning_phase這個變量設置爲1或0進行預測,如果結果有顯著的差別,說明你中招了。不過learning_phase這個參數通常不建議手工指定,learning_phase不會改變已經編譯後的模型的狀態,所以最好是新建一個乾淨的session,在定義graph中的變量之前指定learning_phase。

  • 檢查AUC和ACC,如果acc只有50%但auc接近1(並且測試和訓練表現有明顯不同),很可能是BN迷之縮放的鍋。類似的,在迴歸問題上你可以比較MSE和Spearman‘s correlation來檢查。

2.3 如何修復

如果BN在測試時真的鎖住了,這個問題就能真正解決。實現上,需要用trainable這個標籤來真正控制BN的行爲,而不僅是用learning_phase來控制。具體實現在GitHub上。

主要是通過安裝補丁:作者提供了三個版本的補丁,安裝自己需要的版本就可以

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn

或者

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@fork/keras2.2.4

用了這個補丁之後,BN凍結後,在訓練時它不會使用mini batch均值方差統計值進行歸一化,而會使用在訓練中學習到的統計值,避免歸一化的突變導致準確率的下降**。如果BN沒有凍結,它也會繼續使用訓練集中得到的統計值。**

原文:
By applying the above fix, when a BN layer is frozen it will no longer use the mini-batch statistics but instead use the ones learned during training. As a result, there will be no discrepancy between training and test modes which leads to increased accuracy. Obviously when the BN layer is not frozen, it will continue using the mini-batch statistics during training.

2.4 評估這個補丁的影響

雖然這個補丁是最近才寫好的,但其中的思想已經在各種各樣的workaround中驗證過了。這些workaround包括:將模型分成兩部分,一部分凍結,一部分不凍結,凍結部分只過一遍提取特徵,訓練時只訓練不凍結的部分。爲了增加說服力,我會給出一些例子來展示這個補丁的真實影響。

  • 我會用一小塊數據來刻意過擬合模型,用相同的數據來訓練和驗證模型,那麼在訓練集和驗證集上都應該達到接近100%的準確率。
  • 如果驗證的準確率低於訓練準確率,說明當前的BN實現在推導中是有問題的。
  • 預處理在generator之外進行,因爲keras2.1.5中有一個相關的bug,在2.1.6中修復了。
  • 在推導時使用不同的learning_phase設置,如果兩種設置下準確率不同,說明確實中招了。

代碼如下:

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize
 
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K
 
 
seed = 42
epochs = 10
records_per_class = 100
 
# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()
 
def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]
 
x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)
 
 
# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)
 
for layer in model.layers[:140]:
   layer.trainable = False
 
for layer in model.layers[140:]:
   layer.trainable = True
 
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))
 
# Store the model on disk
model.save('tmp.h5')
 
 
# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))
 
 
print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))
 
 
print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

輸出如下:

Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600

DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]

STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]

如上文所述,驗證集準確率確實要差一些。

訓練完成後,我們做了三個實驗,DYNAMIC LEARNING_PHASE是默認操作,由Keras內部機制動態決定learning_phase,static兩種是手工指定learning_phase,分爲設爲0和1.當learning_phase設爲1時,驗證集的效果提升了,因爲模型正是使用訓練集的均值和方差統計值來訓練的,而這些統計值與凍結的BN中存儲的值不同,凍結的BN中存儲的是預訓練數據集的均值和方差,不會在訓練中更新,會在測試中使用。這種BN的行爲不一致性導致了推導時準確率下降。

加了補丁後的效果:

Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000
 
DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]
 
STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]
 
STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]

模型收斂得更快,改變learning_phase也不再影響模型的準確率了,因爲現在BN都會使用訓練集的均值和方差進行歸一化。

2.5 這個修復在真實數據集上表現如何

我們用Keras預訓練的ResNet50,在CIFAR10上開展實驗,只訓練分類層10個epoch,以及139層以後5個epoch。沒有用補丁的時候準確率爲87.44%,用了之後準確率爲92.36%,提升了5個點。

2.6 其他層是否也要做類似的修復呢?

Dropout在訓練時和測試時的表現也不同,但Dropout是用來避免過擬合的,如果在訓練時也將其凍結在測試模式,Dropout就沒用了,所以Dropout被frozen時,我們還是讓它保持能夠隨機丟棄單元的現狀吧。

參考文獻:
https://zhuanlan.zhihu.com/p/56225304
http://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章