在pytorch訓練數據,發現迭代到某一個次數時,就會報錯,大概意思是輸入的數據的batch size變了,不是預設置的了,後來發現是在DataLoader中有一個參數,控制dataset中的數據個數不是batch_size的整數倍時,剩下的不足batch size個的數據是否會被丟棄。
DataLoader的函數定義如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False)
drop_last:dataset中的數據個數可能不是batch_size的整數倍,drop_last爲True會將多出來不足一個batch的數據丟棄。
所以就在代碼里加上了這個參數爲True,繼續訓練就不再報錯了。