1 問題描述
在使用PyTorch編程的時候,經常遇到一種報錯就是:“RuntimeError: running_mean should contain *** elements not ***”;
這次我具體的報錯信息是:
File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
exponential_average_factor, self.eps)
File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm
training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 192 elements not 768
從最後一行的報錯信息,可以看到:進行求均值元素的總數應該是192而不是768;
2 解決方案
我們可以繼續看看上一條提示信息:“File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm”
有一個值得注意的信息是batch_norm,而我們的模型中也剛好使用了BN的操作,所以應該是BN的設置出現了問題,
我們回到代碼定位的部分進行查看,需要查看的是BN初始化設置的代碼,然後看到了下面的代碼:
modules = [nn.Sequential( nn.Conv2d(in_channels, OUT_CHANNELS, 1, groups=1, bias=False), nn.BatchNorm2d(in_channels), nn.ReLU()),
我們可以看到,果然,BatchNorm2d的輸入通道數與前一層Conv2d的輸出通道數不一致,而這裏的OUT_CHANNELS=192,in_channels=192,所以造成了這種維度的不一致,所以纔會報錯;
所以,我們需要根據自己模型的設計,將BN層與Conv層的輸出維度保持一致。