keras課程5-callbacks

1. 簡述

callbacks中文含義爲回調。該模塊主要是一些回調函數,即在模型訓練的某個時刻執行該回調函數。如early stopping,模型訓練過程中,發現n次效果都不再提升,即停止訓練。

2. EarlyStopping

早停,第一章簡述已經提過,其參數如下:
在這裏插入圖片描述
案例:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                    epochs=10, batch_size=1, callbacks=[callback],
                    verbose=0)

備註:這裏的monitor是被觀察指標,這個指標都有哪些取值,可參照第6章LambdaCallback:它是由每個epoch中的最後一個進度條決定的,最後一個進度條顯示了訓練集和驗證集的loss(compile中的loss)及metrics(compile中的metrics),monitor可選擇其中任何一個:
在這裏插入圖片描述

3. ModelCheckpoint

在模型訓練的時候,我們需要保存模型,ModelCheckpoint即確定保存模型的時刻及保存符合某條件的模型。
在這裏插入圖片描述
注意上述參數中,只有save_freq爲"epoch"時,filepath才能寫成"model_{epoch:02d}-{val_acc:.2f}.hdf5"或者weight.{epoch:02d}-{val_acc:.2f}.hdf5等形式。
案例:

EPOCHS = 10
checkpoint_filepath = '/tmp/model.h5'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_acc',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)

4. History

該回調函數,自動應用於每個keras模型的fit方法
如第二章講述EarlyStopping時的案例。直接history=model.fit(…),也不需要from tensorflow.keras.callbacks import History。history記錄了訓練過程中的結果,如每個epoch的train_loss/train_acc/val_loss/val_acc等。

5. LearningRateSchedule

該函數爲動態調整學習率(每個epoch的初始學習率不同),其默認表達如下:

tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)

可以看到一個重要參數schedule,該參數爲自定義函數,輸入爲epoch,輸入爲浮點型的學習率。通過以下案例說明:

def scheduler(epoch):
  if epoch < 10:  
    return 0.001   
  else:
    return 0.001 * tf.math.exp(0.1 * (10 - epoch))

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
model.fit(data, labels, epochs=100, callbacks=[callback],
          validation_data=(val_data, val_labels))

以下爲實際當中的一個例子

def schedule(epoch_idx):
    if epoch_idx < 10:
        return 0.1 / 10 * (epoch_idx+1)
    else:
        t = (epoch_idx - 10) * math.pi / 90 
        return  1/2 * (1 + math.cos(t)) * 0.1


scheduler = LearningRateScheduler(schedule=schedule)
model = bi_lstm_attention(max_len, max_cnt, embed_size, embedding_matrix)  # 模型實例化
model.compile(loss='categorical_crossentropy',
                      optimizer=tf.keras.optimizers.SGD(lr=0.0, momentum=0.9, decay=0.0, 	   
                      nesterov=False),
                      metrics=['accuracy'])    # 使用了學習率調整,sgd的lr可設置爲0
history = model.fit(X_train, X_train_label,
                  validation_data=(X_val, X_val_label),
                  epochs=100, batch_size=64,
                  shuffle=True,
                  callbacks=[early_stopping, scheduler],
                 )

學習率調整的方法,見博客:https://www.cnblogs.com/xym4869/p/11654611.html,keras中的具體實現後續補充

6. LambdaCallback

其結構如下:

tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None,  
    on_batch_end=None,
    on_train_begin=None, on_train_end=None, **kwargs
)

每一個參數都是一個lambda的函數,lambda函數必須有兩個參數或一個參數,具體要求如下:
在這裏插入圖片描述
舉例:

print_callback = LambdaCallback(on_batch_begin=lambda batch,logs: print(batch, logs),
                                on_batch_end=lambda batch, logs:print('=', batch, logs)
                               )

on_batch_begin表示在每個batch的開始,on_batch_end表示在每個batch的結尾。其某案例結果如下所示:
在這裏插入圖片描述
上圖中進度條顯示的是loss和acc,其中loss由compile的loss決定,acc由compile的metrics決定,如果 metrics=[‘Recall’],則結果如下圖:
在這裏插入圖片描述
官網另一個例子如下:

# Stream the epoch loss to a file in JSON format. The file content
# is not well-formed JSON but rather has a JSON object per line.
import json
json_log = open('loss_log.json', mode='wt', buffering=1)
json_logging_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),   # 每個epoch結束後將epoch、loss寫入文件
    on_train_end=lambda logs: json_log.close()  # 訓練結束後關閉句柄
)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章