Keras深度學習庫包括三個獨立的函數,可用於訓練您自己的模型:
1.Keras的.fit,.fit_generator和.train_on_batch函數之間的區別
2.在訓練自己的深度學習模型時,何時使用每個函數
3.如何實現自己的Keras數據生成器,並在使用.fit_generator訓練模型時使用它
4.在訓練完成後評估網絡時,如何使用.predict_generator函數
fit:
model.fit(trainX, trainY, batch_size=32, epochs=50)
在這裏您可以看到我們提供的訓練數據(trainX)和訓練標籤(trainY)。
然後,我們指示Keras允許我們的模型訓練50個epoch,同時batch size爲32。
對.fit的調用在這裏做出兩個主要假設:
我們的整個訓練集可以放入RAM
沒有數據增強(即不需要Keras生成器)
相反,我們的網絡將在原始數據上訓練。
原始數據本身將適合內存,我們無需將舊批量數據從RAM中移出並將新批量數據移入RAM。
此外,我們不會使用數據增強動態操縱訓練數據。
Keras fit_generator:
對於小型,簡單化的數據集,使用Keras的.fit函數是完全可以接受的。
這些數據集通常不是很具有挑戰性,不需要任何數據增強。
但是,真實世界的數據集很少這麼簡單:
真實世界的數據集通常太大而無法放入內存中
它們也往往具有挑戰性,要求我們執行數據增強以避免過擬合併增加我們的模型的泛化能力
在這些情況下,我們需要利用Keras的.fit_generator函數:
# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32
# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
horizontal_flip=True, fill_mode="nearest")
# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS)
執行數據增強是正則化的一種形式,使我們的模型能夠更好的被泛化。
但是,應用數據增強意味着我們的訓練數據不再是“靜態的” ——數據不斷變化。
根據提供給ImageDataGenerator的參數隨機調整每批新數據。
因此,我們現在需要利用Keras的.fit_generator函數來訓練我們的模型。
顧名思義,.fit_generator函數假定存在一個爲其生成數據的基礎函數。
該函數本身是一個Python生成器。
Keras在使用.fit_generator訓練模型時的過程:
- Keras調用提供給
.fit_generator
的生成器函數(在本例中爲aug.flow
) - 生成器函數爲
.fit_generator
函數生成一批大小爲BS
的數據 .fit_generator
函數接受批量數據,執行反向傳播,並更新模型中的權重- 重複該過程直到達到期望的epoch數量
您會注意到我們現在需要在調用.fit_generator時提供steps_per_epoch參數(.fit方法沒有這樣的參數)。
爲什麼我們需要steps_per_epoch?
請記住,Keras數據生成器意味着無限循環,它永遠不會返回或退出。
由於該函數旨在無限循環,因此Keras無法確定一個epoch何時開始的,並且新的epoch何時開始。
因此,我們將訓練數據的總數除以批量大小的結果作爲steps_per_epoch的值。一旦Keras到達這一步,它就會知道這是一個新的epoch。
Keras train_on_batch函數:
對於尋求對Keras模型進行精細控制( finest-grained control)的深度學習實踐者,您可能希望使用
.train_on_batch
函數:
model.train_on_batch(batchX, batchY)
train_on_batch函數接受單批數據,執行反向傳播,然後更新模型參數。
該批數據可以是任意大小的(即,它不需要提供明確的批量大小)。
您也可以生成數據。此數據可以是磁盤上的原始圖像,也可以是以某種方式修改或擴充的數據。
當您有非常明確的理由想要維護自己的訓練數據迭代器時,通常會使用.train_on_batch函數,例如數據迭代過程非常複雜並且需要自定義代碼。
如果你發現自己在詢問是否需要.train_on_batch函數,那麼很有可能你可能不需要。
在99%的情況下,您不需要對訓練深度學習模型進行如此精細的控制。相反,您可能只需要自定義Keras .fit_generator函數。
也就是說,如果你需要它,知道存在這個函數是很好的。
如果您是一名高級深度學習從業者/工程師,並且您確切知道自己在做什麼以及爲什麼這樣做,我通常只建議使用.train_on_batch函數。