文章目录
训练神经网络时,控制训练的周期很关键。如果过早停止训练,可能导致欠拟合;如果训练周期过长,可能导致过拟合,从而导致泛化能力很差。一种解决方案是在训练数据集上进行训练,在验证数据集的性能开始下降时停止训练。这种简单,有效且广泛使用的训练神经网络的方法称为及时停止(Early Stopping)。
1. 如何使用及时停止
及时停止要求网络配置为处于受限状态,这意味着网络具有比问题所需容量更多的容量。
在训练网络时,使用比通常更多的训练周期,以使网络有足够的周期拟合,然后设置及时停止使在训练周期合适时停止训练。使用及时停止有三个要素:
- 要监视的指标;
- 触发停止条件;
- 要保存的模型;
1.1 要监视的指标
训练神经网络时,通常会从数据集中拆分出一个子集(例如30%)作为验证数据集,用于在训练过程中监视模型的性能。该验证集不用于训练模型。通常也可以使用验证数据集上的损失作为监视指标。
一般来说,在回归问题中,使用验证集上的预测误差作为指标;在分类问题中,使用验证集上的准确率作为指标。
训练数据集上模型的损失也将作为训练过程的一部分提供,还可以在训练数据集上计算和监视其他指标。
在每个时期结束时在验证集上评估模型的性能,这会增加训练期间的额外计算成本。可以通过不那么频繁地评估模型(例如每2、5或10个训练时期)来减少这种情况。
1.2 触发停止条件
选择了模型的监视指标之后,就要设置停止训练的触发器。
触发器将使用监视的性能指标来决定何时停止训练。这通常是模型在验证集上的性能,例如验证损失(val_loss)。
在最简单的情况下,与先前训练时期(例如,val_loss增加)相比,验证数据集的性能下降后,训练就会立即停止。
在实践中可能需要更详细的触发器。这是因为神经网络的训练是随机的,并且可能夹杂很多噪声。绘制验证损失和验证准确率曲线可以看出,模型的性能可能会多次上升和下降。也就是说,出现第一个过度拟合迹象时就停止训练是不妥当的,因为实际的验证集上的误差曲线存在多个局部极小值。
一些更详细的触发器可能包括:
- 给定周期内,指标没有变化;
- 给定周期内,指标的绝对变化;
- 给定周期内,指标的性能下降;
- 给定周期内,指标的平均变化;
1.3 要保存的模型
在停止训练时,该模型的泛化误差比先前时期的模型大一些。
因此,需要考虑合适保存模型或者说如何保存性能最好的模型,亦即保存训练过程中哪个模型的权重。这取决于为停止训练过程而选择的触发器。例如,如果触发是性能从一个时期到下一个时期的降低,那么将优先考虑模型在先前时期的权重。如果要求触发器在固定时期内观察到性能下降,则将首选触发器周期开始时的模型。
一个常用的方法是:每当验证集上的误差改善时,就保存此时模型权重文件。当终止训练时,得到的模型权重就是最佳模型的权重。
2. 及时停止使用技巧
2.1 适用范围
几乎所有的神经网络都需要设置及时停止。
2.2 通过绘制曲线观察
在使用及时停止之前,可能需要设置较长的训练周期来让模型进行拟合,并在训练和验证数据集上监视模型的性能。实时或长期运行结束后绘制模型的性能,通过观察监视指标的变化情况,有助于选择提前停止的触发器。
2.3 监视指标选择
损失是在训练过程中进行监控并触发提前停止的简单指标。
问题在于,损失并不能总反映出最适合业务场景需求的模型。最好选择一种性能指标进行监视,以最好地定义模型的性能。
2.4 训练周期选择
及时停止的一个问题是模型没有利用所有可用的训练数据。
这可能需要避免过拟合并在所有可能的数据上进行训练,尤其是在训练数据量非常有限的情况下。
推荐的方法是将训练时期的数量视为超参数,使用k-折交叉验证对不同值的范围进行网格搜索。可以固定训练周期的数量,并在所有可用数据上拟合最终模型。
及时停止过程可以重复多次。可以记录停止训练的周期数。然后,在将最终模型拟合到所有可用的训练数据上时,可以使用及时停止的所有周期数的平均值。
每次运行早期停止时,都可以使用不同的训练集划分为训练和验证步骤来执行此过程。一种替代方法可能是使用验证数据集及时停止,然后通过对所提供的验证集进行进一步训练来更新最终模型。
2.5 过拟合验证
多次重复及时停止过程可能会导致模型过度拟合验证数据集,这和过度拟合训练数据集一样容易。一种方法是在选择了模型的所有其他超参数后才使用及时停止。
另一种策略可能是每次使用及时停止时,都将训练数据集分为不同的训练集和验证集。
3. TensorFlow API
3.1 拆分验证集
在 tensorflow.keras 中,有两种方式可以设置验证集:
...
model.fit(train_X, train_y, validation_data=(val_x, val_y))
...
model.fit(train_X, train_y, validation_split=0.3)
3.2 设置监视指标
如果在 model.fit()
API中设置了 validation_data
或 validation_split
参数,则会返回验证数据集上的损失,名称为 val_loss
。
可以在编译模型时通过 model.compile
函数的 metrics
参数来指定它们。此参数使用Python列表传入,例如 mse
表示均方误差,precision
表示精度。常用监视指标:
...
model.compile(..., metrics=['accuracy'])
如果在训练时监视其它指标,也可通过相同的名称提供给 metrics
参数,如 val_accuracy
表示验证集上的准确率。mse
表示训练集上的均方误差,val_mse
表示验证集上的均方误差。
3.3 EarlyStopping API
tf.keras.callbacks.EarlyStopping(
monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
baseline=None, restore_best_weights=False
)
参数说明:
monitor
:监视指标。min_delta
:监视指标的最小变化(有改进的定义);即绝对变化小于min_delta
,不视为改进。patience
:没有改进的时期数,之后训练将停止。verbose
:在控制台打印信息的级别;verbose设置为1时,训练停止时返回训练周期数。mode
:{"auto", "min", "max"}
中选其一。在min
模式下,当监视的数量停止减少时,训练将停止;在max 模式下,当监视的数量停止增加时,它将停止;在auto
模式下,将根据监视数量的名称自动推断出变化方向;baseline
:监视数量的基准值。如果模型没有超过基线的改善,则停止训练。restore_best_weights
:是否使用该训练周期中监视指标最佳模型所对应的权重。如果为False
,则使用在训练的最后一步获得的模型权重。
3.4 ModelCheckPoint API
tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)
用于以一定间隔保存模型或权重(在检查点文件中),之后可以重新加载模型或权重以从保存的状态继续训练。
filepath
:字符串,保存模型文件的路径。filepath可以包含命名的格式选项,这会传递给该API的方法on_epoch_end
中。例如:如果filepath为weights.{epoch:02d}-{val_loss:.2f}.hdf5
,则模型将以epoch和验证损失作为文件名保存。monitor
:监视指标。verbose
:在控制台打印信息的级别,0或1。save_best_only
:如果为save_best_only=True
,则最佳模型不会被覆盖。如果filepath不包含格式设置选项,原名称会被覆盖。save_weights_only
:如果为True,则仅保存模型的权重(model.save_weights(filepath)
),否则保存完整的模型(model.save(filepath)
)。save_freq
:'epoch'
或整数。使用'epoch'
时,回调函数会在每个时期后保存模型。使用整数时,回调将在许多批次结束时保存模型。请注意,如果保存未与时间段保持一致,则受监视的指标可能会不太可靠(它可能只反映1个批次,因为每个epoch都会重置该指标)。默认为’epoch’。
3.5 实例
# 定义模型
model = Sequential()
model.add(Dense(500, input_dim=2, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# 定义回调
callbacks_set = [EarlyStopping(monitor='val_loss', mode='min', verbose=1),
ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)]
# 训练模型
# 及时停止,同时保存验证集上性能最好(验证损失最低)的模型
history = model.fit(trainX, trainy, validation_data=(testX, testy), epochs=4000, verbose=0,
callbacks=callbacks_set)
tensorboard log 等等其它回调同理。
https://machinelearningmastery.com/early-stopping-to-avoid-overtraining-neural-network-models/
https://machinelearningmastery.com/how-to-stop-training-deep-neural-networks-at-the-right-time-using-early-stopping/
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping?hl=en
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint?hl=en