日萌社
人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度學習實戰(不定時更新)
1.使用repeat() 解決如下報錯:
WARNING:tensorflow:Your input ran out of data; interrupting training.
Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs`
batches (in this case, 414 batches). You may need to use the repeat() function
when building your dataset.
警告:tensorflow:輸入的數據用完;中斷訓練。請確保您的數據集或生成器至少可以生成
“每個steps_per_epoch*epoch”批(在本例中爲414 batches)。在構建數據集時,可能需要使用repeat()函數。
2.解決:
方式一:
如果使用 model.fit(train_dataset, validation_data=test_dataset, steps_per_epoch=M, epochs=N),
即同時steps_per_epoch和epochs的話,則要求把訓練集和驗證集都拷貝N份(epochs份),
因爲使用了steps_per_epoch之後,每次epoch的遍歷是對不同份的訓練集和驗證集進行遍歷,
不再是對同一份訓練集和驗證集進行遍歷,因此一開始就要拷貝N份(epochs份)的源訓練集和源驗證集,
供以遍歷epoch次。可以使用repeat(epochs)對數據集進行拷貝epochs份。
方式二:
如果使用 model.fit(train_dataset, validation_data=test_dataset, epochs=N),
只使用了epochs的話,那麼每次epoch的遍歷都是對同一份訓練集和驗證集進行遍歷,便不需要把訓練集和驗證集都拷貝N份(epochs份),
其中steps_per_epoch值不需要傳入,即會在第一次epoch的遍歷中自動計算steps_per_epoch值。
3.例子:batch、repeat、steps_per_epoch、epochs的使用
1.先用 batch(批次大小) 然後才用 repeat(重複次數),
比如:repeat(2)重複數據集2次,即複製一份數據集,最終即有兩份數據集。
2.steps_per_epoch 即表示 一個epoch 裏面遍歷批量數據的次數,即遍歷多少個批量數據完成一個epoch。
3.應保證要輸入到模型的數據集(包括訓練集/驗證集)的批量個數都均爲steps_per_epoch*epoch。
4.fit中沒有定義steps_per_epoch,只定義了epochs的話,那麼只會對同一份數據集進行遍歷epochs次進行訓練。
如果fit中同時有傳訓練集和驗證集validation_data進行訓練/驗證的話,那麼均爲對同一份訓練集和驗證集進行遍歷epochs次進行訓練/驗證。
打印的訓練信息格式如下:當前步數step/總步數steps - ETA:剩餘訓練時間 - loss - accuracy
5.fit中同時定義了steps_per_epoch和epochs的話,那麼表示對數據集進行遍歷epochs次進行訓練,並且每個epoch中遍歷steps_per_epoch個批量數據。
但要注意的是此處所說的對數據集進行遍歷epochs次指的不是對同一份數據集遍歷epochs次,而是對epochs份數據集遍歷epochs次,
而每份數據集遍歷1次,因此需要對原始數據集拷貝epochs份,才能每份數據集遍歷1次,一共遍歷epochs次進行訓練。
如果fit中同時有傳訓練集和驗證集validation_data進行訓練/驗證的話,那麼同時要把訓練集和驗證集validation_data都拷貝epochs份。
可以使用repeat(epochs)對數據集進行拷貝epochs份。
第一個epoch打印的訓練信息格式如下:
當前步數step/Unknown - ETA:剩餘訓練時間 - loss: - accuracy:
第一個epoch之後的每個epoch打印的訓練信息格式如下:(因爲經過第一個epoch之後底層就已經計算好每個epoch需要遍歷多少個批量數據,即得出總步數steps)
當前步數step/總步數steps - ETA:剩餘訓練時間 - loss - accuracy
4.使用順序:from_tensor_slices -> map -> shuffle -> batch -> repeat -> prefetch