keras 冻结指定层(设置为不可训练/可训练)

Fine-tune 冻结指定层

fine-tune 某些公开模型时,由于我们自己的任务类别数会与公开模型的类别数不同,因此通常的做法是将模型的最后一层的改变,并且固定全连接层之前的模型权重重新训练

  • 如下面的例子,我们使用inceptionV3 模型作为base model,后面接上1*1的卷积层和全连接层。为了复用InceptionV3模型的参数,应当设置base model不可训练及冻结base model的参数。
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, load_model
from keras.layers import Activation, Dropout, Flatten, Reshape, Dense, Concatenate, GlobalMaxPooling2D
from keras.layers import BatchNormalization, Input, Conv2D, Lambda, Average
from keras.applications.inception_v3 import InceptionV3
from keras.callbacks import ModelCheckpoint
from keras import metrics
from keras.optimizers import Adam
from keras import backend as K
import keras
from keras.models import Model

    
def create_model(n_out):
    input_shape=(WINDOW_SIZE,WINDOW_SIZE, IMAGE_CHANNELS)
    input_tensor = Input(shape=(WINDOW_SIZE, WINDOW_SIZE, IMAGE_CHANNELS))
    base_model = InceptionV3(include_top=False,
                             weights='imagenet',
                             input_shape=input_shape
                             #input_shape=(WINDOW_SIZE, WINDOW_SIZE, IMAGE_CHANNELS)
                            )
    bn = BatchNormalization()(input_tensor)
    x = base_model(bn)
    x = Conv2D(32, kernel_size=(1,1), activation='relu')(x)
    x = Flatten()(x)
    x = Dropout(0.5)(x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.5)(x)
    output = Dense(n_out, activation='sigmoid')(x)
    model = Model(input_tensor, output)
    
    return model

# warm up model
model = create_model(n_out=NUM_CLASSES)
  • 首先设置所有层的trainable属性为False,然后设置最后的6层trainable为True
for layer in model.layers:
    layer.trainable = False
# 或者使用如下方法冻结所有层
# model.trainable = False 
model.layers[-1].trainable = True
model.layers[-2].trainable = True
model.layers[-3].trainable = True
model.layers[-4].trainable = True
model.layers[-5].trainable = True
model.layers[-6].trainable = True
  • 查看哪些层可训练或者不可训练
# 可训练层
for x in model.trainable_weights:
    print(x.name)
print('\n')

# 不可训练层
for x in model.non_trainable_weights:
    print(x.name)
print('\n')

进阶参考:[ Keras ] ——基本使用:(2) fine-tune+冻结层+抽取模型某一层输出
参考:
https://keras.io/getting-started/faq/#how-can-i-freeze-keras-layers

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章