Keras回調函數使用

回調函數使用

回調函數是一個函數的合集,會在訓練的階段中所使用。你可以使用回調函數來查看訓練模型的內在狀態和統計。你可以傳遞一個列表的回調函數(作爲 callbacks 關鍵字參數)到 SequentialModel 類型的 .fit() 方法。在訓練時,相應的回調函數的方法就會被在各自的階段被調用。

模型生成並保存以後每次只需要調用保存好的模型,不需要重新訓練

Callback

keras.callbacks.Callback()

用來組建新的回調函數的抽象基類。

屬性

  • params: 字典。訓練參數, (例如,verbosity, batch size, number of epochs...)。
  • model: keras.models.Model 的實例。 指代被訓練模型。

被回調函數作爲參數的 logs 字典,它會含有於當前批量或訓練輪相關數據的鍵。

目前,Sequential 模型類的 .fit() 方法會在傳入到回調函數的 logs 裏面包含以下的數據:

  • on_epoch_end: 包括 accloss 的日誌, 也可以選擇性的包括 val_loss(如果在 fit 中啓用驗證),和 val_acc(如果啓用驗證和監測精確值)。
  • on_batch_begin: 包括 size 的日誌,在當前批量內的樣本數量。
  • on_batch_end: 包括 loss 的日誌,也可以選擇性的包括 acc(如果啓用監測精確值)。

 

ModelCheckpoint

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

在每個訓練期之後保存模型。

filepath 可以包括命名格式選項,可以由 epoch 的值和 logs 的鍵(由 on_epoch_end 參數傳遞)來填充。

例如:如果 filepathweights.{epoch:02d}-{val_loss:.2f}.hdf5, 那麼模型被保存的的文件名就會有訓練輪數和驗證損失。

參數

  • filepath: 字符串,保存模型的路徑。
  • monitor: 被監測的數據。
  • verbose: 詳細信息模式,0 或者 1 。
  • save_best_only: 如果 save_best_only=True, 被監測數據的最佳模型就不會被覆蓋。
  • mode: {auto, min, max} 的其中之一。 如果 save_best_only=True,那麼是否覆蓋保存文件的決定就取決於被監測數據的最大或者最小值。 對於 val_acc,模式就會是 max,而對於 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向會自動從被監測的數據的名字中判斷出來。
  • save_weights_only: 如果 True,那麼只有模型的權重會被保存 (model.save_weights(filepath)), 否則的話,整個模型會被保存 (model.save(filepath))。
  • period: 每個檢查點之間的間隔(訓練輪數)。

CSVLogger

keras.callbacks.CSVLogger(filename, separator=',', append=False)

把訓練輪結果數據流到 csv 文件的回調函數。

支持所有可以被作爲字符串表示的值,包括 1D 可迭代數據,例如,np.ndarray。

例子

csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])

參數

  • filename: csv 文件的文件名,例如 'run/log.csv'。
  • separator: 用來隔離 csv 文件中元素的字符串。
  • append: True:如果文件存在則增加(可以被用於繼續訓練)。False:覆蓋存在的文件。

EarlyStopping

keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)

當被監測的數量不再提升,則停止訓練。

參數

  • monitor: 被監測的數據。
  • min_delta: 在被監測的數據中被認爲是提升的最小變化, 例如,小於 min_delta 的絕對變化會被認爲沒有提升。
  • patience: 沒有進步的訓練輪數,在這之後訓練就會被停止。
  • verbose: 詳細信息模式。
  • mode: {auto, min, max} 其中之一。 在 min 模式中, 當被監測的數據停止下降,訓練就會停止;在 max 模式中,當被監測的數據停止上升,訓練就會停止;在 auto 模式中,方向會自動從被監測的數據的名字中判斷出來。
  • baseline: 要監控的數量的基準值。 如果模型沒有顯示基準的改善,訓練將停止。
  • restore_best_weights: 是否從具有監測數量的最佳值的時期恢復模型權重。 如果爲 False,則使用在訓練的最後一步獲得的模型權重。

ReduceLROnPlateau

keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)

當標準評估停止提升時,降低學習速率。

當學習停止時,模型總是會受益於降低 2-10 倍的學習速率。 這個回調函數監測一個數據並且當這個數據在一定「有耐心」的訓練輪之後還沒有進步, 那麼學習速率就會被降低。

例子

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])

參數

  • monitor: 被監測的數據。
  • factor: 學習速率被降低的因數。新的學習速率 = 學習速率 * 因數
  • patience: 沒有進步的訓練輪數,在這之後訓練速率會被降低。
  • verbose: 整數。0:安靜,1:更新信息。
  • mode: {auto, min, max} 其中之一。如果是 min 模式,學習速率會被降低如果被監測的數據已經停止下降; 在 max 模式,學習塑料會被降低如果被監測的數據已經停止上升; 在 auto 模式,方向會被從被監測的數據中自動推斷出來。
  • min_delta: 對於測量新的最優化的閥值,只關注巨大的改變。
  • cooldown: 在學習速率被降低之後,重新恢復正常操作之前等待的訓練輪數量。
  • min_lr: 學習速率的下邊界。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章