pytorch:data讀取出錯:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension

     在使用Dataloader讀取數據的時候,使用batch_size=1不會出現這個問題。當batch_size>1時, 默認將會使用torch.stack()爲你生成一個[batch,x, x, x] 的tensor數據,在使用該函數時需要輸入的兩個tensor維度一樣。

1、注意自己圖像的大小是否resize到相同尺寸

2、圖像的通道數是否相同,全綵色或全灰度,使用相同的類型。

3、如果是CNN的多標籤分類,標籤長度不同也可以用下面的辦法。

3、在進行目標檢測的時候,我們返回的不僅僅是圖像數據,還有它的gt_box以及gt_label。但是每個圖像的標籤個數不是相同的,所以在使用Dataloader原始的參數時也會報錯。此時可以自己寫一個 collate_fn函數,因爲我的返回爲圖像、目標框以及類別標籤所以修改如下

def data_collate(batch):
    gt_box = []
    gt_label = []
    imgs = []
    for info in batch:
        imgs.append(info[0])
        gt_box.append(info[1])
        gt_label.append(info[2])
    return torch.stack(imgs, 0), gt_box,gt_label

train裏面調用:DataLoader(data_mine, batch_size=2,collate_fn=data_collate, shuffle=True, num_workers=2)

根據自己重寫的Dataset類的返回值修改上面,得到自己想要的數據。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章