【调参06】如何通过设置及时停止避免过拟合



训练神经网络时,控制训练的周期很关键。如果过早停止训练,可能导致欠拟合;如果训练周期过长,可能导致过拟合,从而导致泛化能力很差。一种解决方案是在训练数据集上进行训练,在验证数据集的性能开始下降时停止训练。这种简单,有效且广泛使用的训练神经网络的方法称为及时停止(Early Stopping)。

1. 如何使用及时停止

及时停止要求网络配置为处于受限状态,这意味着网络具有比问题所需容量更多的容量。

在训练网络时,使用比通常更多的训练周期,以使网络有足够的周期拟合,然后设置及时停止使在训练周期合适时停止训练。使用及时停止有三个要素:

  • 要监视的指标;
  • 触发停止条件;
  • 要保存的模型;

1.1 要监视的指标

训练神经网络时,通常会从数据集中拆分出一个子集(例如30%)作为验证数据集,用于在训练过程中监视模型的性能。该验证集不用于训练模型。通常也可以使用验证数据集上的损失作为监视指标。

一般来说,在回归问题中,使用验证集上的预测误差作为指标;在分类问题中,使用验证集上的准确率作为指标。

训练数据集上模型的损失也将作为训练过程的一部分提供,还可以在训练数据集上计算和监视其他指标。

在每个时期结束时在验证集上评估模型的性能,这会增加训练期间的额外计算成本。可以通过不那么频繁地评估模型(例如每2、5或10个训练时期)来减少这种情况。


1.2 触发停止条件

选择了模型的监视指标之后,就要设置停止训练的触发器。

触发器将使用监视的性能指标来决定何时停止训练。这通常是模型在验证集上的性能,例如验证损失(val_loss)。

在最简单的情况下,与先前训练时期(例如,val_loss增加)相比,验证数据集的性能下降后,训练就会立即停止。

在实践中可能需要更详细的触发器。这是因为神经网络的训练是随机的,并且可能夹杂很多噪声。绘制验证损失和验证准确率曲线可以看出,模型的性能可能会多次上升和下降。也就是说,出现第一个过度拟合迹象时就停止训练是不妥当的,因为实际的验证集上的误差曲线存在多个局部极小值。

一些更详细的触发器可能包括:

  • 给定周期内,指标没有变化;
  • 给定周期内,指标的绝对变化;
  • 给定周期内,指标的性能下降;
  • 给定周期内,指标的平均变化;

1.3 要保存的模型

在停止训练时,该模型的泛化误差比先前时期的模型大一些。

因此,需要考虑合适保存模型或者说如何保存性能最好的模型,亦即保存训练过程中哪个模型的权重。这取决于为停止训练过程而选择的触发器。例如,如果触发是性能从一个时期到下一个时期的降低,那么将优先考虑模型在先前时期的权重。如果要求触发器在固定时期内观察到性能下降,则将首选触发器周期开始时的模型。

一个常用的方法是:每当验证集上的误差改善时,就保存此时模型权重文件。当终止训练时,得到的模型权重就是最佳模型的权重。


2. 及时停止使用技巧

2.1 适用范围

几乎所有的神经网络都需要设置及时停止。


2.2 通过绘制曲线观察

在使用及时停止之前,可能需要设置较长的训练周期来让模型进行拟合,并在训练和验证数据集上监视模型的性能。实时或长期运行结束后绘制模型的性能,通过观察监视指标的变化情况,有助于选择提前停止的触发器。


2.3 监视指标选择

损失是在训练过程中进行监控并触发提前停止的简单指标。

问题在于,损失并不能总反映出最适合业务场景需求的模型。最好选择一种性能指标进行监视,以最好地定义模型的性能。


2.4 训练周期选择

及时停止的一个问题是模型没有利用所有可用的训练数据。

这可能需要避免过拟合并在所有可能的数据上进行训练,尤其是在训练数据量非常有限的情况下。

推荐的方法是将训练时期的数量视为超参数,使用k-折交叉验证对不同值的范围进行网格搜索。可以固定训练周期的数量,并在所有可用数据上拟合最终模型。

及时停止过程可以重复多次。可以记录停止训练的周期数。然后,在将最终模型拟合到所有可用的训练数据上时,可以使用及时停止的所有周期数的平均值。

每次运行早期停止时,都可以使用不同的训练集划分为训练和验证步骤来执行此过程。一种替代方法可能是使用验证数据集及时停止,然后通过对所提供的验证集进行进一步训练来更新最终模型。


2.5 过拟合验证

多次重复及时停止过程可能会导致模型过度拟合验证数据集,这和过度拟合训练数据集一样容易。一种方法是在选择了模型的所有其他超参数后才使用及时停止。

另一种策略可能是每次使用及时停止时,都将训练数据集分为不同的训练集和验证集。


3. TensorFlow API

3.1 拆分验证集

在 tensorflow.keras 中,有两种方式可以设置验证集:

...
model.fit(train_X, train_y, validation_data=(val_x, val_y))
...
model.fit(train_X, train_y, validation_split=0.3)

3.2 设置监视指标

如果在 model.fit() API中设置了 validation_datavalidation_split 参数,则会返回验证数据集上的损失,名称为 val_loss

可以在编译模型时通过 model.compile 函数的 metrics 参数来指定它们。此参数使用Python列表传入,例如 mse 表示均方误差,precision 表示精度。常用监视指标:

...
model.compile(..., metrics=['accuracy'])

如果在训练时监视其它指标,也可通过相同的名称提供给 metrics 参数,如 val_accuracy 表示验证集上的准确率。mse 表示训练集上的均方误差,val_mse 表示验证集上的均方误差。


3.3 EarlyStopping API

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
    baseline=None, restore_best_weights=False
)

参数说明:

  • monitor:监视指标。
  • min_delta:监视指标的最小变化(有改进的定义);即绝对变化小于 min_delta,不视为改进。
  • patience:没有改进的时期数,之后训练将停止。
  • verbose:在控制台打印信息的级别;verbose设置为1时,训练停止时返回训练周期数。
  • mode{"auto", "min", "max"} 中选其一。在 min 模式下,当监视的数量停止减少时,训练将停止;在max 模式下,当监视的数量停止增加时,它将停止;在 auto 模式下,将根据监视数量的名称自动推断出变化方向;
  • baseline:监视数量的基准值。如果模型没有超过基线的改善,则停止训练。
  • restore_best_weights:是否使用该训练周期中监视指标最佳模型所对应的权重。如果为 False,则使用在训练的最后一步获得的模型权重。

3.4 ModelCheckPoint API

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)

用于以一定间隔保存模型或权重(在检查点文件中),之后可以重新加载模型或权重以从保存的状态继续训练。

  • filepath:字符串,保存模型文件的路径。filepath可以包含命名的格式选项,这会传递给该API的方法 on_epoch_end 中。例如:如果filepath为 weights.{epoch:02d}-{val_loss:.2f}.hdf5,则模型将以epoch和验证损失作为文件名保存。
  • monitor:监视指标。
  • verbose:在控制台打印信息的级别,0或1。
  • save_best_only:如果为 save_best_only=True,则最佳模型不会被覆盖。如果filepath不包含格式设置选项,原名称会被覆盖。
  • save_weights_only:如果为True,则仅保存模型的权重(model.save_weights(filepath)),否则保存完整的模型(model.save(filepath))。
  • save_freq'epoch'或整数。使用'epoch'时,回调函数会在每个时期后保存模型。使用整数时,回调将在许多批次结束时保存模型。请注意,如果保存未与时间段保持一致,则受监视的指标可能会不太可靠(它可能只反映1个批次,因为每个epoch都会重置该指标)。默认为’epoch’。

3.5 实例

# 定义模型
model = Sequential()
model.add(Dense(500, input_dim=2, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# 定义回调
callbacks_set = [EarlyStopping(monitor='val_loss', mode='min', verbose=1),
				ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)]
# 训练模型
# 及时停止,同时保存验证集上性能最好(验证损失最低)的模型
history = model.fit(trainX, trainy, validation_data=(testX, testy), epochs=4000, verbose=0, 
					callbacks=callbacks_set)

tensorboard log 等等其它回调同理。


https://machinelearningmastery.com/early-stopping-to-avoid-overtraining-neural-network-models/
https://machinelearningmastery.com/how-to-stop-training-deep-neural-networks-at-the-right-time-using-early-stopping/
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping?hl=en
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint?hl=en

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章