Keras Callbacks Explained

Introduction

Callbacks are a set of functions invoked at specific stages of training; you can use them to observe the network's internal state and statistics during training. By passing a list of callbacks to a model's .fit(), the functions in that list are called at the given stages of training.
Although we call them callback "functions", a Keras callback is in fact a class.

keras.callbacks.Callback()

is the abstract base class of all callbacks; a new callback must be defined by subclassing it.
Class attributes

  • params: a dict of training parameters (e.g. verbosity, batch size, number of epochs)

  • model: a keras.models.Model object, a reference to the model being trained

The callback methods take a dict logs as an argument, which contains information about the current batch or epoch.

Currently, the following quantities from the model's .fit() are recorded in logs:

  • At the end of each epoch (on_epoch_end), logs contains the training accuracy and loss, acc and loss; if a validation set was given, it also contains the validation accuracy and loss, val_acc and val_loss. val_acc additionally requires metrics=['accuracy'] to be enabled in .compile().

  • At the beginning of each batch (on_batch_begin): logs contains size, the number of samples in the current batch

  • At the end of each batch (on_batch_end): logs contains loss, and also acc if accuracy monitoring is enabled
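
To make this calling convention concrete, here is a minimal pure-Python sketch. It does not use Keras itself; the class and the toy training loop are illustrative stand-ins for keras.callbacks.Callback and .fit():

```python
# A pure-Python sketch of how .fit() drives callback hooks with a `logs`
# dict; the hook names and arguments mirror keras.callbacks.Callback.
class HistoryCallback:
    def on_train_begin(self, logs=None):
        self.epoch_logs = []

    def on_batch_end(self, batch, logs=None):
        # logs contains at least 'size' and 'loss' ('acc' if enabled)
        pass

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_logs.append(dict(logs or {}))

# A toy "training loop" invoking the hooks the way .fit() would:
cb = HistoryCallback()
cb.on_train_begin()
for epoch in range(2):
    for batch in range(3):
        cb.on_batch_end(batch, logs={'size': 32, 'loss': 1.0 / (batch + 1)})
    cb.on_epoch_end(epoch, logs={'loss': 1.0 / (epoch + 1),
                                 'acc': 0.5 + 0.1 * epoch})

print(len(cb.epoch_logs))  # 2: one logs dict recorded per epoch
```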

from keras.callbacks import Callback

Using callbacks to control a model during training

Originally, the workflow was: train once, look at the validation accuracy curve to find the best epoch, set that epoch count, and then train again to get the final result. Very wasteful!

A way to save time is to terminate training once the validation accuracy stops improving. Keras callbacks can do this for us.
A callback is an object that is passed to the model when fitting and is called by the model at various points during training. It has access to the model's state and can interrupt training, save the model, load a different set of weights, or otherwise alter the model's state.
Callbacks can be used to:

  • Checkpoint the model: save the current weights of the model
  • Stop early: terminate training when the loss stops decreasing (while keeping the best model seen so far)
  • Dynamically adjust training parameters, such as the learning rate of the optimizer
  • And so on

EarlyStopping: stopping training early

  • monitor is the metric to watch; here we monitor 'acc', the accuracy

  • patience is the number of consecutive epochs without improvement in the monitored metric after which training is terminated; here it is set to 1
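
The decision rule behind EarlyStopping can be sketched in plain Python. This is an approximation of the assumed semantics, not the actual Keras implementation; should_stop and acc_history are hypothetical names:

```python
# Sketch of the early-stopping rule: stop once the monitored metric has
# gone more than `patience` epochs without improving on its best value.
def should_stop(acc_history, patience=1):
    # Index of the epoch with the best (highest) monitored value so far
    best_epoch = max(range(len(acc_history)), key=lambda i: acc_history[i])
    # Stop when more than `patience` epochs have passed without improvement
    return len(acc_history) - 1 - best_epoch > patience

print(should_stop([0.70, 0.80, 0.79]))        # False: best was 1 epoch ago
print(should_stop([0.70, 0.80, 0.79, 0.78])) # True: 2 epochs, no improvement
```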

ModelCheckpoint: saving the best model

  • filepath is the location and file name under which the model is stored, with an .h5 suffix

  • monitor is the metric to watch; here we monitor the validation loss
  • save_best_only means that only the best result seen during training is kept
  • validation_data is the validation set; note it is an argument of .fit(), not of the callback itself
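
The save_best_only logic can be sketched as follows (a pure-Python approximation of the assumed semantics; improved and saved_epochs are illustrative names, and the real callback serializes the model to filepath instead of recording epoch numbers):

```python
# Sketch of the `save_best_only` rule: only checkpoint when the monitored
# value improves on the best value seen so far.
def improved(current, best, mode='min'):
    return current < best if mode == 'min' else current > best

best = float('inf')
saved_epochs = []
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.6]):
    if improved(val_loss, best):
        best = val_loss
        saved_epochs.append(epoch)  # the real callback would save to filepath

print(saved_epochs)  # → [0, 1, 3]: epoch 2 did not improve, so no save
```

The real Keras usage, combined with EarlyStopping, is shown below.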
    import keras
    # Callbacks are passed to the model via the `callbacks` argument in `fit`,
    # which takes a list of callbacks. You can pass any number of callbacks.
    callbacks_list = [
        # This callback will interrupt training when we stop improving
        keras.callbacks.EarlyStopping(
            # This callback will monitor the training accuracy of the model
            monitor='acc',
            # Training will be interrupted when the accuracy
            # has stopped improving for *more* than 1 epoch (i.e. 2 epochs)
            patience=1,
        ),
        # This callback will save the current weights after every epoch
        keras.callbacks.ModelCheckpoint(
            filepath='my_model.h5',  # Path to the destination model file
            # The two arguments below mean that we will not overwrite the
            # model file unless `val_loss` has improved, which
            # allows us to keep the best model ever seen during training.
            monitor='val_loss',
            save_best_only=True,
        )
    ]
    # Since we monitor `acc`, it should be part of the metrics of the model.
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
    # Note that since ModelCheckpoint monitors `val_loss`,
    # we need to pass some `validation_data` to our call to `fit`.
    model.fit(x, y,
              epochs=10,
              batch_size=32,
              callbacks=callbacks_list,
              validation_data=(x_val, y_val))

ReduceLROnPlateau: reducing the learning rate

    callbacks_list = [
        keras.callbacks.ReduceLROnPlateau(
            # This callback will monitor the validation loss of the model
            monitor='val_loss',
            # It will divide the learning rate by 10 when it gets triggered
            factor=0.1,
            # It will get triggered after the validation loss has stopped
            # improving for at least 10 epochs
            patience=10,
        )
    ]
    # Note that since the callback monitors validation loss,
    # we need to pass some `validation_data` to our call to `fit`.
    model.fit(x, y,
              epochs=30,
              batch_size=32,
              callbacks=callbacks_list,
              validation_data=(x_val, y_val))

  • patience: if val_loss has not improved for 10 consecutive epochs, the learning rate is multiplied by 0.1
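
The schedule can be sketched in plain Python. This mirrors the assumed semantics of monitor/factor/patience and is not the actual Keras implementation; plateau_schedule is a hypothetical helper:

```python
# Sketch of the ReduceLROnPlateau rule: when the monitored value fails to
# improve for `patience` consecutive epochs, multiply the lr by `factor`.
def plateau_schedule(val_losses, lr=0.1, factor=0.1, patience=10, min_lr=1e-6):
    best = float('inf')
    wait = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs

# One improving epoch followed by 10 flat epochs triggers one reduction:
lrs = plateau_schedule([1.0] + [1.0] * 10)
print(lrs[0], lrs[-1])  # starts at 0.1, ends reduced to ~0.01
```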

TensorBoard: visualization

keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)

TensorBoard is the visualization tool provided with TensorFlow. This callback writes log information for TensorBoard, so that you can dynamically watch plots of the training and test metrics, as well as histograms of the activation values of the different layers.

If you have installed TensorFlow with pip, you can launch TensorBoard from the command line:

tensorboard --logdir=/full_path_to_your_logs

  • log_dir: the directory in which to save the log files, which TensorBoard parses for visualization

LambdaCallback

keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None)

A class for creating simple, lightweight callbacks on the fly.

The anonymous functions of this callback are invoked at the appropriate times. Note that the callbacks expect positional arguments, as follows:

  • on_epoch_begin and on_epoch_end expect the arguments epoch, logs
  • on_batch_begin and on_batch_end expect the arguments batch, logs
  • on_train_begin and on_train_end expect the argument logs

Arguments

  • on_epoch_begin: called at the beginning of every epoch
  • on_epoch_end: called at the end of every epoch
  • on_batch_begin: called at the beginning of every batch
  • on_batch_end: called at the end of every batch
  • on_train_begin: called at the beginning of model training
  • on_train_end: called at the end of model training
from keras.callbacks import LambdaCallback

# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(
    on_batch_begin=lambda batch, logs: print(batch))

# Plot the loss after every epoch (one point per epoch; plotting the
# scalar logs['loss'] against a whole range would raise an error).
import matplotlib.pyplot as plt
plot_loss_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: plt.plot([epoch], [logs['loss']], 'bo'))

# Terminate some processes after having finished model training.
processes = ...
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])

model.fit(...,
          callbacks=[batch_print_callback,
                     plot_loss_callback,
                     cleanup_callback])

Custom callbacks

If the built-in callbacks do not meet your needs, you can write your own by subclassing keras.callbacks.Callback. A callback accesses the model being trained through the class attribute self.model, which holds a reference to the model.

A simple callback that records the loss of every batch:

from keras.callbacks import Callback

class LossHistory(Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
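
In real training you would pass an instance via callbacks=[LossHistory()] to .fit() and read .losses afterwards. To show the mechanics without running Keras, the hooks can also be driven by hand (the class is restated here without the Keras base class so the snippet is self-contained):

```python
# Self-contained restatement of LossHistory (minus the Keras base class)
# so its hooks can be driven manually for illustration.
class LossHistory:
    def on_train_begin(self, logs=None):
        self.losses = []
    def on_batch_end(self, batch, logs=None):
        self.losses.append((logs or {}).get('loss'))

history = LossHistory()
history.on_train_begin()
# Simulate three batches the way .fit() would report them:
for batch, loss in enumerate([0.9, 0.7, 0.6]):
    history.on_batch_end(batch, logs={'loss': loss})

print(history.losses)  # → [0.9, 0.7, 0.6]
```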

Saving the activation values to disk as arrays after every epoch:

import numpy as np
import keras

class ActivationLogger(keras.callbacks.Callback):
    def set_model(self, model):
        # This method is called by the parent model before training,
        # to inform the callback of what model will be calling it
        self.model = model
        layer_outputs = [layer.output for layer in model.layers]
        # This is a model instance that returns the activations of every layer
        self.activations_model = keras.models.Model(model.input, layer_outputs)

    def on_epoch_end(self, epoch, logs=None):
        if self.validation_data is None:
            raise RuntimeError('Requires validation_data.')
        # Obtain the first input sample of the validation data
        validation_sample = self.validation_data[0][0:1]
        activations = self.activations_model.predict(validation_sample)
        # Save the arrays to disk (npz files must be written in binary mode)
        with open('activations_at_epoch_' + str(epoch) + '.npz', 'wb') as f:
            np.savez(f, *activations)

Reference:
https://keras-cn.readthedocs.io/en/latest/other/callbacks/
