Keras训练模型

Keras深度学习库包括三个独立的函数,可用于训练您自己的模型:

1.Keras的.fit,.fit_generator和.train_on_batch函数之间的区别
2.在训练自己的深度学习模型时,何时使用每个函数
3.如何实现自己的Keras数据生成器,并在使用.fit_generator训练模型时使用它
4.在训练完成后评估网络时,如何使用.predict_generator函数

fit:

model.fit(trainX, trainY, batch_size=32, epochs=50)

在这里您可以看到我们提供的训练数据(trainX)和训练标签(trainY)。

然后,我们指示Keras允许我们的模型训练50个epoch,同时batch size为32。

对.fit的调用在这里做出两个主要假设:

我们的整个训练集可以放入RAM
没有数据增强(即不需要Keras生成器)
相反,我们的网络将在原始数据上训练。

原始数据本身将适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。

此外,我们不会使用数据增强动态操纵训练数据。

Keras fit_generator:

对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。

这些数据集通常不是很具有挑战性,不需要任何数据增强。

但是,真实世界的数据集很少这么简单:

真实世界的数据集通常太大而无法放入内存中
它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力
在这些情况下,我们需要利用Keras的.fit_generator函数:

# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest")

# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
	validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
	epochs=EPOCHS)

执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。

但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。

根据提供给ImageDataGenerator的参数随机调整每批新数据。

因此,我们现在需要利用Keras的.fit_generator函数来训练我们的模型。

顾名思义,.fit_generator函数假定存在一个为其生成数据的基础函数。

该函数本身是一个Python生成器。

Keras在使用.fit_generator训练模型时的过程:

  • Keras调用提供给.fit_generator的生成器函数(在本例中为aug.flow
  • 生成器函数为.fit_generator函数生成一批大小为BS的数据
  • .fit_generator函数接受批量数据,执行反向传播,并更新模型中的权重
  • 重复该过程直到达到期望的epoch数量

您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。

为什么我们需要steps_per_epoch?

请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。

由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。

因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。

Keras train_on_batch函数:

对于寻求对Keras模型进行精细控制( finest-grained control)的深度学习实践者,您可能希望使用

.train_on_batch

函数:

model.train_on_batch(batchX, batchY)

train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数。

该批数据可以是任意大小的(即,它不需要提供明确的批量大小)。

您也可以生成数据。此数据可以是磁盘上的原始图像,也可以是以某种方式修改或扩充的数据。

当您有非常明确的理由想要维护自己的训练数据迭代器时,通常会使用.train_on_batch函数,例如数据迭代过程非常复杂并且需要自定义代码。

如果你发现自己在询问是否需要.train_on_batch函数,那么很有可能你可能不需要。

在99%的情况下,您不需要对训练深度学习模型进行如此精细的控制。相反,您可能只需要自定义Keras .fit_generator函数。

也就是说,如果你需要它,知道存在这个函数是很好的。

如果您是一名高级深度学习从业者/工程师,并且您确切知道自己在做什么以及为什么这样做,我通常只建议使用.train_on_batch函数。
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章