使用Keras 的Model.fit_generator報錯StopIteration

使用Keras 的Model.fit_generator報錯StopIteration

之前也遇到過這個問題,解決了之後沒記下來,最近跑之前代碼又出現這個,廢了時間去找答案,還是要勤勞點做學習記錄纔行。

報錯如下,問題就是批量產生的數據沒有成功一批批地導入。

Epoch 1/100

Epoch 00001: CosineAnnealingScheduler setting learning rate to 0.001.
    1/20843 [..............................] - ETA: 145:53:01 - loss: 13.0753 - acc: 0.0000e+00Traceback (most recent call last):
  File "train-v2-20191123.py", line 311, in <module>
    main()
  File "train-v2-20191123.py", line 275, in main
    verbose=1)
  File "/home/jiajie/anaconda3/envs/py35/lib/python3.5/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/jiajie/anaconda3/envs/py35/lib/python3.5/site-packages/keras/engine/training.py", line 1415, in fit_generator
    initial_epoch=initial_epoch)
  File "/home/jiajie/anaconda3/envs/py35/lib/python3.5/site-packages/keras/engine/training_generator.py", line 177, in fit_generator
    generator_output = next(output_generator)
  File "/home/jiajie/anaconda3/envs/py35/lib/python3.5/site-packages/keras/utils/data_utils.py", line 785, in get
    raise StopIteration()
StopIteration

解決方法如下:
要注意兩點:

  • 第一點是:要在循環輸出部分加上while 1 :
  • 第二點是:計算數據累計次數的參數要記得歸零,即下面代碼中的cnt
cnt=0
x=[]
y=[]
while 1for path in dirpath:
        X = cv2.imread(picpath,cv2.IMREAD_COLOR).astype(np.float32) / 255
        label = getTwoDimensionListIndex(gt2,name)
        x.append(X)
        y.append(label)
        cnt += 1
        if cnt == batch_size :
            cnt=0  #don't forget to set this number to 0
            yield (np.array(x),np.array(y))
            x=[]
            y=[]

基本問題就是上面兩點,不要漏了就能解決StopIteration的問題了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章