pytorch報錯:RuntimeError: Given groups=1, weight of size 10 3 3 3, expected input[1, 12, 12, 3]

在mtcnn人臉檢測中,網絡訓練時報以下錯誤:
在這裏插入圖片描述
後來找到原因如下:
圖片經過處理後的數據格式是 NHWC,而pytorch輸入圖片的格式要求是NCHW,需要轉化一下。有兩種方法:
一是用下列方法:

import torch
image = torch.randn(1,12 ,12, 3)
image = image.permute(0,3,1,2)
print(image.shape)#([1, 3, 12, 12])

二是用torchvision.transform.ToTensor(),這個函數會自動將numpy處理圖片後的NHWC轉化爲NCHW,並進行(0,1)的歸一化,非常方便。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章