Keras訓練模型

Keras深度學習庫包括三個獨立的函數,可用於訓練您自己的模型:

1.Keras的.fit,.fit_generator和.train_on_batch函數之間的區別
2.在訓練自己的深度學習模型時,何時使用每個函數
3.如何實現自己的Keras數據生成器,並在使用.fit_generator訓練模型時使用它
4.在訓練完成後評估網絡時,如何使用.predict_generator函數

fit:

model.fit(trainX, trainY, batch_size=32, epochs=50)

在這裏您可以看到我們提供的訓練數據(trainX)和訓練標籤(trainY)。

然後,我們指示Keras允許我們的模型訓練50個epoch,同時batch size爲32。

對.fit的調用在這裏做出兩個主要假設:

我們的整個訓練集可以放入RAM
沒有數據增強(即不需要Keras生成器)
相反,我們的網絡將在原始數據上訓練。

原始數據本身將適合內存,我們無需將舊批量數據從RAM中移出並將新批量數據移入RAM。

此外,我們不會使用數據增強動態操縱訓練數據。

Keras fit_generator:

對於小型,簡單化的數據集,使用Keras的.fit函數是完全可以接受的。

這些數據集通常不是很具有挑戰性,不需要任何數據增強。

但是,真實世界的數據集很少這麼簡單:

真實世界的數據集通常太大而無法放入內存中
它們也往往具有挑戰性,要求我們執行數據增強以避免過擬合併增加我們的模型的泛化能力
在這些情況下,我們需要利用Keras的.fit_generator函數:

# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest")

# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
	validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
	epochs=EPOCHS)

執行數據增強是正則化的一種形式,使我們的模型能夠更好的被泛化。

但是,應用數據增強意味着我們的訓練數據不再是“靜態的” ——數據不斷變化。

根據提供給ImageDataGenerator的參數隨機調整每批新數據。

因此,我們現在需要利用Keras的.fit_generator函數來訓練我們的模型。

顧名思義,.fit_generator函數假定存在一個爲其生成數據的基礎函數。

該函數本身是一個Python生成器。

Keras在使用.fit_generator訓練模型時的過程:

  • Keras調用提供給.fit_generator的生成器函數(在本例中爲aug.flow
  • 生成器函數爲.fit_generator函數生成一批大小爲BS的數據
  • .fit_generator函數接受批量數據,執行反向傳播,並更新模型中的權重
  • 重複該過程直到達到期望的epoch數量

您會注意到我們現在需要在調用.fit_generator時提供steps_per_epoch參數(.fit方法沒有這樣的參數)。

爲什麼我們需要steps_per_epoch?

請記住,Keras數據生成器意味着無限循環,它永遠不會返回或退出。

由於該函數旨在無限循環,因此Keras無法確定一個epoch何時開始的,並且新的epoch何時開始。

因此,我們將訓練數據的總數除以批量大小的結果作爲steps_per_epoch的值。一旦Keras到達這一步,它就會知道這是一個新的epoch。

Keras train_on_batch函數:

對於尋求對Keras模型進行精細控制( finest-grained control)的深度學習實踐者,您可能希望使用

.train_on_batch

函數:

model.train_on_batch(batchX, batchY)

train_on_batch函數接受單批數據,執行反向傳播,然後更新模型參數。

該批數據可以是任意大小的(即,它不需要提供明確的批量大小)。

您也可以生成數據。此數據可以是磁盤上的原始圖像,也可以是以某種方式修改或擴充的數據。

當您有非常明確的理由想要維護自己的訓練數據迭代器時,通常會使用.train_on_batch函數,例如數據迭代過程非常複雜並且需要自定義代碼。

如果你發現自己在詢問是否需要.train_on_batch函數,那麼很有可能你可能不需要。

在99%的情況下,您不需要對訓練深度學習模型進行如此精細的控制。相反,您可能只需要自定義Keras .fit_generator函數。

也就是說,如果你需要它,知道存在這個函數是很好的。

如果您是一名高級深度學習從業者/工程師,並且您確切知道自己在做什麼以及爲什麼這樣做,我通常只建議使用.train_on_batch函數。
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章