Keras深度学习库包括三个独立的函数,可用于训练您自己的模型:
1.Keras的.fit,.fit_generator和.train_on_batch函数之间的区别
2.在训练自己的深度学习模型时,何时使用每个函数
3.如何实现自己的Keras数据生成器,并在使用.fit_generator训练模型时使用它
4.在训练完成后评估网络时,如何使用.predict_generator函数
fit:
model.fit(trainX, trainY, batch_size=32, epochs=50)
在这里您可以看到我们提供的训练数据(trainX)和训练标签(trainY)。
然后,我们指示Keras允许我们的模型训练50个epoch,同时batch size为32。
对.fit的调用在这里做出两个主要假设:
我们的整个训练集可以放入RAM
没有数据增强(即不需要Keras生成器)
相反,我们的网络将在原始数据上训练。
原始数据本身将适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。
此外,我们不会使用数据增强动态操纵训练数据。
Keras fit_generator:
对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。
这些数据集通常不是很具有挑战性,不需要任何数据增强。
但是,真实世界的数据集很少这么简单:
真实世界的数据集通常太大而无法放入内存中
它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力
在这些情况下,我们需要利用Keras的.fit_generator函数:
# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32
# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
horizontal_flip=True, fill_mode="nearest")
# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS)
执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。
但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。
根据提供给ImageDataGenerator的参数随机调整每批新数据。
因此,我们现在需要利用Keras的.fit_generator函数来训练我们的模型。
顾名思义,.fit_generator函数假定存在一个为其生成数据的基础函数。
该函数本身是一个Python生成器。
Keras在使用.fit_generator训练模型时的过程:
- Keras调用提供给
.fit_generator
的生成器函数(在本例中为aug.flow
) - 生成器函数为
.fit_generator
函数生成一批大小为BS
的数据 .fit_generator
函数接受批量数据,执行反向传播,并更新模型中的权重- 重复该过程直到达到期望的epoch数量
您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。
为什么我们需要steps_per_epoch?
请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。
由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。
因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。
Keras train_on_batch函数:
对于寻求对Keras模型进行精细控制( finest-grained control)的深度学习实践者,您可能希望使用
.train_on_batch
函数:
model.train_on_batch(batchX, batchY)
train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数。
该批数据可以是任意大小的(即,它不需要提供明确的批量大小)。
您也可以生成数据。此数据可以是磁盘上的原始图像,也可以是以某种方式修改或扩充的数据。
当您有非常明确的理由想要维护自己的训练数据迭代器时,通常会使用.train_on_batch函数,例如数据迭代过程非常复杂并且需要自定义代码。
如果你发现自己在询问是否需要.train_on_batch函数,那么很有可能你可能不需要。
在99%的情况下,您不需要对训练深度学习模型进行如此精细的控制。相反,您可能只需要自定义Keras .fit_generator函数。
也就是说,如果你需要它,知道存在这个函数是很好的。
如果您是一名高级深度学习从业者/工程师,并且您确切知道自己在做什么以及为什么这样做,我通常只建议使用.train_on_batch函数。