Keras Callbacks Explained

Introduction

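Callbacks are a set of functions that are called at specific stages of training. You can use them to inspect the network's internal state and statistics while it trains. Pass a list of callbacks to the model's .fit() method and their functions will be called at the corresponding stages of training.
Although we speak of callback "functions", a Keras callback is actually a class.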

keras.callbacks.Callback()

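is the abstract base class for callbacks. Every new callback must inherit from it.

Class attributes

  • params: dict of training parameters (e.g. verbosity, batch size, number of epochs)

  • model: a keras.models.Model instance, a reference to the model being trained

Callback methods take a dict, logs, as an argument. It holds information relevant to the current batch or epoch.

Currently, the model's .fit() method records the following in logs:

  • At the end of each epoch (on_epoch_end): logs contains the training accuracy and loss, acc and loss. If a validation set is specified, it also contains the validation accuracy and loss, val_acc and val_loss. acc and val_acc are only present if metrics=['accuracy'] is enabled in .compile().

  • At the start of each batch (on_batch_begin): logs contains size, the number of samples in the current batch

  • At the end of each batch (on_batch_end): logs contains loss, plus acc if accuracy is enabled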
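
To make the calling order concrete, here is a minimal pure-Python sketch of a training loop dispatching to callback hooks. It does not use Keras: RecordingCallback and toy_fit are hypothetical names invented for this illustration, and the logs dicts only mimic the keys described above.

```python
class RecordingCallback:
    """Hypothetical stand-in for a Keras callback; records every logs dict it receives."""

    def __init__(self):
        self.events = []

    def on_batch_begin(self, batch, logs=None):
        self.events.append(("batch_begin", batch, dict(logs or {})))

    def on_batch_end(self, batch, logs=None):
        self.events.append(("batch_end", batch, dict(logs or {})))

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(("epoch_end", epoch, dict(logs or {})))


def toy_fit(batch_losses, batch_size, callbacks):
    """Simulate a single epoch, calling each hook with a logs dict."""
    for batch, loss in enumerate(batch_losses):
        for cb in callbacks:
            cb.on_batch_begin(batch, {"size": batch_size})
        # A real training step would run here; we just reuse the given loss.
        for cb in callbacks:
            cb.on_batch_end(batch, {"loss": loss})
    epoch_loss = sum(batch_losses) / len(batch_losses)
    for cb in callbacks:
        cb.on_epoch_end(0, {"loss": epoch_loss})


recorder = RecordingCallback()
toy_fit([0.9, 0.7], batch_size=32, callbacks=[recorder])
for event in recorder.events:
    print(event)
```

Each batch produces one batch_begin and one batch_end event, and the epoch closes with a single epoch_end event carrying the averaged loss.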

from keras.callbacks import Callback

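Using callbacks to control a model during training

At first, the workflow was to train once, look at how the validation accuracy changed to find the best epoch, set that number of epochs, and then train again to get the final result. This wastes a lot of time.

One way to save time is to stop training as soon as the validation accuracy stops improving. Keras callbacks can do this for us.
A callback is an object that is passed to the model in the call to fit and is called by the model at various points during training. It has access to the model's state and can act on it: interrupt training, save the model, load a different set of weights, or otherwise alter the model's state.
Callbacks can be used for things like:

  • Model checkpointing: saving all of the model's current weights so that training can be resumed
  • Early stopping: interrupting training when the loss stops decreasing, while keeping the best model obtained
  • Dynamically adjusting training parameters, such as the optimizer's learning rate
  • And so on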

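EarlyStopping: stopping early

  • monitor: the metric to watch; here we choose 'acc', the accuracy (use 'val_acc' to watch the validation accuracy instead)

  • patience: the number of consecutive epochs without improvement in the monitored metric after which training is stopped; here it is set to 1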
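
The patience rule can be sketched in plain Python. This is an illustration of the idea only, not Keras's implementation: early_stopping_epoch is a hypothetical helper, the accuracy values are made up, and the exact epoch at which Keras stops may differ by one between versions.

```python
def early_stopping_epoch(history, patience, mode="max"):
    """Return the index of the epoch at which training would stop, or None.

    history: one metric value per epoch (illustrative data, not real results).
    mode: "max" if larger is better (accuracy), "min" if smaller is better (loss).
    """
    best = None
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(history):
        if best is None:
            improved = True
        elif mode == "max":
            improved = value > best
        else:
            improved = value < best
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


acc_per_epoch = [0.70, 0.78, 0.81, 0.80, 0.82, 0.82, 0.81]
stop_with_patience_1 = early_stopping_epoch(acc_per_epoch, patience=1)
stop_with_patience_2 = early_stopping_epoch(acc_per_epoch, patience=2)
print(stop_with_patience_1, stop_with_patience_2)
```

With patience=1 the dip at epoch index 3 already ends training, while patience=2 tolerates that dip and only stops after two non-improving epochs in a row.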

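ModelCheckpoint: saving the best model

  • filepath: where the model is saved and under what name, with the .h5 extension
  • monitor: the metric to watch; here we watch the loss on the validation set (val_loss)
  • save_best_only: keep only the best result seen during training
  • validation_data: the validation data, passed to fit()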
    import keras

    # Callbacks are passed to the model via the `callbacks` argument of `fit`,
    # which takes a list of callbacks. You can pass any number of callbacks.
    callbacks_list = [
        # This callback interrupts training when the monitored metric stops improving.
        keras.callbacks.EarlyStopping(
            # Monitor the accuracy of the model; use 'val_acc' to monitor
            # the validation accuracy instead.
            monitor='acc',
            # Number of epochs with no improvement after which
            # training is interrupted.
            patience=1,
        ),
        # This callback saves the model to `filepath` at the end of an epoch.
        keras.callbacks.ModelCheckpoint(
            filepath='my_model.h5',  # Path to the destination model file
            # The two arguments below mean that the model file is not
            # overwritten unless `val_loss` has improved, which keeps
            # the best model seen during training.
            monitor='val_loss',
            save_best_only=True,
        ),
    ]
    # Since we monitor `acc`, it must be part of the metrics of the model.
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
    # ModelCheckpoint monitors the validation loss, so we need to pass
    # `validation_data` to our call to `fit`.
    # `model`, `x`, `y`, `x_val` and `y_val` are assumed to be defined already.
    model.fit(x, y,
              epochs=10,
              batch_size=32,
              callbacks=callbacks_list,
              validation_data=(x_val, y_val))
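
The effect of save_best_only can be sketched without Keras. The helper checkpoint_epochs is hypothetical and the loss values are made up; the sketch only shows at which epochs a best-only checkpoint on val_loss would overwrite the file.

```python
def checkpoint_epochs(val_losses):
    """Return the epochs at which a best-only checkpoint would be written."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # lower validation loss counts as an improvement
            best = loss
            saved.append(epoch)  # a real callback would save the model here
    return saved


val_loss_per_epoch = [0.62, 0.48, 0.51, 0.45, 0.47]
saved = checkpoint_epochs(val_loss_per_epoch)
print(saved)
```

Epochs 2 and 4 are skipped because their validation loss is worse than the best value seen so far, so the file on disk always holds the best model up to that point.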

ReduceLROnPlateau: reducing the learning rate

    callbacks_list = [
        keras.callbacks.ReduceLROnPlateau(
            # This callback monitors the validation loss of the model.
            monitor='val_loss',
            # It divides the learning rate by 10 when it gets triggered.
            factor=0.1,
            # It gets triggered after the validation loss has stopped
            # improving for 10 epochs.
            patience=10,
        )
    ]
    # Since the callback monitors the validation loss,
    # we need to pass `validation_data` to our call to `fit`.
    model.fit(x, y,
              epochs=30,
              batch_size=32,
              callbacks=callbacks_list,
              validation_data=(x_val, y_val))

  • patience: if val_loss has not decreased for 10 consecutive epochs, the learning rate is reduced to 0.1 times its current value
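
A simplified pure-Python sketch of this rule follows. It is not the Keras implementation: lr_schedule_on_plateau is a hypothetical helper, the loss values are made up, and options that the real callback also has (such as a cooldown period and a lower bound on the learning rate) are left out.

```python
def lr_schedule_on_plateau(val_losses, initial_lr, factor=0.1, patience=10):
    """Return the learning rate in effect after each epoch."""
    lr = initial_lr
    best = float("inf")
    wait = 0  # consecutive epochs without a lower validation loss
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor  # shrink the learning rate
                wait = 0      # and start counting again
        lrs.append(lr)
    return lrs


# A small patience keeps the example short.
lrs = lr_schedule_on_plateau([1.0, 0.8, 0.8, 0.8, 0.79],
                             initial_lr=0.01, factor=0.1, patience=2)
print(lrs)
```

The learning rate stays at its initial value until the validation loss has stalled for two epochs, then drops by a factor of 10 and stays there once the loss starts improving again.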

TensorBoard: visualization

keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)

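TensorBoard is the visualization tool provided with TensorFlow. This callback writes logs for TensorBoard, so that you can watch plots of the training and test metrics as they change, along with histograms of the activations of the different layers.

If TensorFlow was installed with pip, TensorBoard can be launched with the following command:

tensorboard --logdir=/full_path_to_your_logs

  • log_dir: the directory where the log files are saved; TensorBoard parses these files for visualization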

LambdaCallback

keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None)

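A callback class for creating simple callbacks.

The anonymous functions passed to this callback are called at the appropriate times. Note that each one is expected to take specific positional arguments:

  • on_epoch_begin and on_epoch_end expect two positional arguments: epoch, logs
  • on_batch_begin and on_batch_end expect two positional arguments: batch, logs
  • on_train_begin and on_train_end expect one positional argument: logs

Arguments

  • on_epoch_begin: called at the beginning of every epoch
  • on_epoch_end: called at the end of every epoch
  • on_batch_begin: called at the beginning of every batch
  • on_batch_end: called at the end of every batch
  • on_train_begin: called at the beginning of training
  • on_train_end: called at the end of training
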
from keras.callbacks import LambdaCallback

# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(
    on_batch_begin=lambda batch, logs: print(batch))

# Plot the loss after every epoch. logs['loss'] is a single number,
# so each epoch adds one point to the current figure.
import matplotlib.pyplot as plt
plot_loss_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: plt.plot(epoch, logs['loss'], 'bo'))

# Terminate some processes after having finished model training.
processes = ...
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])

model.fit(...,
          callbacks=[batch_print_callback,
                     plot_loss_callback,
                     cleanup_callback])
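
To show why the positional arguments matter, here is a simplified pure-Python analogue of such a wrapper class. MiniLambdaCallback is a hypothetical name used only for this sketch; it is not the Keras class and makes no claim about how Keras implements it.

```python
class MiniLambdaCallback:
    """Hypothetical, simplified analogue of LambdaCallback (not the Keras class)."""

    def __init__(self, on_epoch_begin=None, on_epoch_end=None,
                 on_batch_begin=None, on_batch_end=None,
                 on_train_begin=None, on_train_end=None):
        def ignore_indexed(index, logs):
            # Default for hooks that receive (epoch or batch, logs).
            return None

        def ignore_train(logs):
            # Default for hooks that receive only logs.
            return None

        # Each hook is the user's function if given, else a no-op.
        self.on_epoch_begin = on_epoch_begin or ignore_indexed
        self.on_epoch_end = on_epoch_end or ignore_indexed
        self.on_batch_begin = on_batch_begin or ignore_indexed
        self.on_batch_end = on_batch_end or ignore_indexed
        self.on_train_begin = on_train_begin or ignore_train
        self.on_train_end = on_train_end or ignore_train


seen = []
cb = MiniLambdaCallback(
    on_epoch_end=lambda epoch, logs: seen.append((epoch, logs["loss"])),
    on_train_end=lambda logs: seen.append("done"),
)

# Simulate the calls a training loop would make.
cb.on_train_begin({})
cb.on_epoch_end(0, {"loss": 0.5})
cb.on_epoch_end(1, {"loss": 0.4})
cb.on_train_end({})
print(seen)
```

A lambda with the wrong number of parameters, for example one that takes only logs for on_epoch_end, would raise a TypeError as soon as the hook is called.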

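Custom callbacks

If the built-in callbacks do not cover your needs, you can write your own by subclassing keras.callbacks.Callback. A callback accesses the model through the class member self.model, which is a reference to the model.

A simple callback that saves the loss of every batch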

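from keras.callbacks import Callback

class LossHistory(Callback):
    def on_train_begin(self, logs=None):
        # Start with an empty list at the beginning of training.
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        # logs may be None, so fall back to an empty dict.
        self.losses.append((logs or {}).get('loss'))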

Saving activations to disk as arrays

import keras
import numpy as np

class ActivationLogger(keras.callbacks.Callback):
    def set_model(self, model):
        # This method is called by the parent model before training,
        # to inform the callback of what model will be calling it.
        self.model = model
        layer_outputs = [layer.output for layer in model.layers]
        # This is a model instance that returns the activations of every layer.
        self.activations_model = keras.models.Model(model.input, layer_outputs)

    def on_epoch_end(self, epoch, logs=None):
        # Assumes a Keras version that sets self.validation_data on callbacks.
        if self.validation_data is None:
            raise RuntimeError('Requires validation_data.')
        # Obtain the first input sample of the validation data.
        validation_sample = self.validation_data[0][0:1]
        activations = self.activations_model.predict(validation_sample)
        # Save the arrays to disk; np.savez writes binary data, so open with 'wb'.
        with open('activations_at_epoch_' + str(epoch) + '.npz', 'wb') as f:
            np.savez(f, *activations)

Reference:
https://keras-cn.readthedocs.io/en/latest/other/callbacks/
