Keras提供兩種學習率適應方法,可通過回調函數實現。
1. LearningRateScheduler
keras.callbacks.LearningRateScheduler(schedule)
該回調函數是學習率調度器.
參數
schedule:函數,該函數以epoch號爲參數(從0算起的整數),返回一個新學習率(浮點數)
代碼
- import keras.backend as K
- from keras.callbacks import LearningRateScheduler
-
- def scheduler(epoch):
- # 每隔100個epoch,學習率減小爲原來的1/10
- if epoch % 100 == 0 and epoch != 0:
- lr = K.get_value(model.optimizer.lr)
- K.set_value(model.optimizer.lr, lr * 0.1)
- print("lr changed to {}".format(lr * 0.1))
- return K.get_value(model.optimizer.lr)
-
- reduce_lr = LearningRateScheduler(scheduler)
- model.fit(train_x, train_y, batch_size=32, epochs=5, callbacks=[reduce_lr])
2. ReduceLROnPlateau
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)
當評價指標不在提升時,減少學習率
當學習停滯時,減少2倍或10倍的學習率常常能獲得較好的效果。該回調函數檢測指標的情況,如果在patience個epoch中看不到模型性能提升,則減少學習率
參數
monitor:被監測的量
factor:每次減少學習率的因子,學習率將以lr = lr*factor的形式被減少
patience:當patience個epoch過去而模型性能不提升時,學習率減少的動作會被觸發
mode:‘auto’,‘min’,‘max’之一,在min模式下,如果檢測值觸發學習率減少。在max模式下,當檢測值不再上升則觸發學習率減少。
epsilon:閾值,用來確定是否進入檢測值的“平原區”
cooldown:學習率減少後,會經過cooldown個epoch才重新進行正常操作
min_lr:學習率的下限
代碼
- from keras.callbacks import ReduceLROnPlateau
- reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')
- model.fit(train_x, train_y, batch_size=32, epochs=5, validation_split=0.1, callbacks=[reduce_lr])
上述文章轉自於 z小白
本文在於如何利用keras尋找最優的學習率,根據Leslie N. Smith,一個簡單的方法是將學習率依次從小到大緩慢增加,每個學習率只用於學習一輪(這個按照數據集而異,可以根據學習的難度進行調整,我在step_per_epoch
進行了調整)
函數如下:
def scheduler(epoch):
# 每隔100個epoch,學習率減小爲原來的1/10
if epoch == 0:
# lr = K.get_value(model.optimizer.lr)
K.set_value(model.optimizer.lr, 1e-6)
print('The initial lr is {}'.format(1e-6))
if epoch % 1 == 0 and epoch != 0:
lr = K.get_value(model.optimizer.lr)
K.set_value(model.optimizer.lr, lr * 10)
print("lr changed to {}".format(lr * 10))
return K.get_value(model.optimizer.lr)
然後使用上面的LearningRateScheduler
方法,設置callback_object,主程序如下,不要忘了打印圖片奧。
...
reduce_lr = LearningRateScheduler(scheduler)
history = model.fit_generator(source_gen,\
epochs=6, steps_per_epoch=1.5*source_num//args.batch_size,\
validation_data=val_gen, validation_steps=5,\
callbacks=[reduce_lr])
val_loss = history.history['val_loss']
loss = [1e-6, 1e-5, 1e-4, 1e-3, 0.01, 0.1]
plt.subplot(121)
plt.plot(loss, val_loss)
plt.subplot(122)
plt.scatter(loss, val_loss)
plt.show()
...