keras 凍結指定層(設置爲不可訓練/可訓練)

Fine-tune 凍結指定層

fine-tune 某些公開模型時,由於我們自己的任務類別數會與公開模型的類別數不同,因此通常的做法是將模型的最後一層的改變,並且固定全連接層之前的模型權重重新訓練

  • 如下面的例子,我們使用inceptionV3 模型作爲base model,後面接上1*1的卷積層和全連接層。爲了複用InceptionV3模型的參數,應當設置base model不可訓練及凍結base model的參數。
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, load_model
from keras.layers import Activation, Dropout, Flatten, Reshape, Dense, Concatenate, GlobalMaxPooling2D
from keras.layers import BatchNormalization, Input, Conv2D, Lambda, Average
from keras.applications.inception_v3 import InceptionV3
from keras.callbacks import ModelCheckpoint
from keras import metrics
from keras.optimizers import Adam
from keras import backend as K
import keras
from keras.models import Model

    
def create_model(n_out):
    input_shape=(WINDOW_SIZE,WINDOW_SIZE, IMAGE_CHANNELS)
    input_tensor = Input(shape=(WINDOW_SIZE, WINDOW_SIZE, IMAGE_CHANNELS))
    base_model = InceptionV3(include_top=False,
                             weights='imagenet',
                             input_shape=input_shape
                             #input_shape=(WINDOW_SIZE, WINDOW_SIZE, IMAGE_CHANNELS)
                            )
    bn = BatchNormalization()(input_tensor)
    x = base_model(bn)
    x = Conv2D(32, kernel_size=(1,1), activation='relu')(x)
    x = Flatten()(x)
    x = Dropout(0.5)(x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.5)(x)
    output = Dense(n_out, activation='sigmoid')(x)
    model = Model(input_tensor, output)
    
    return model

# warm up model
model = create_model(n_out=NUM_CLASSES)
  • 首先設置所有層的trainable屬性爲False,然後設置最後的6層trainable爲True
for layer in model.layers:
    layer.trainable = False
# 或者使用如下方法凍結所有層
# model.trainable = False 
model.layers[-1].trainable = True
model.layers[-2].trainable = True
model.layers[-3].trainable = True
model.layers[-4].trainable = True
model.layers[-5].trainable = True
model.layers[-6].trainable = True
  • 查看哪些層可訓練或者不可訓練
# 可訓練層
for x in model.trainable_weights:
    print(x.name)
print('\n')

# 不可訓練層
for x in model.non_trainable_weights:
    print(x.name)
print('\n')

進階參考:[ Keras ] ——基本使用:(2) fine-tune+凍結層+抽取模型某一層輸出
參考:
https://keras.io/getting-started/faq/#how-can-i-freeze-keras-layers

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章