pytorch报错:RuntimeError: Given groups=1, weight of size 10 3 3 3, expected input[1, 12, 12, 3]

在mtcnn人脸检测中,网络训练时报以下错误:
在这里插入图片描述
后来找到原因如下:
图片经过处理后的数据格式是 NHWC,而pytorch输入图片的格式要求是NCHW,需要转化一下。有两种方法:
一是用下列方法:

import torch
image = torch.randn(1,12 ,12, 3)
image = image.permute(0,3,1,2)
print(image.shape)#([1, 3, 12, 12])

二是用torchvision.transform.ToTensor(),这个函数会自动将numpy处理图片后的NHWC转化为NCHW,并进行(0,1)的归一化,非常方便。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章