PyTorch模型參數統計問題
使用torchsummary
使用torchsummary
可以方便的統計模型的參數量
from torchsummary import summary
model = net().cuda()
input_shape = (1,256,256)
summary(model, input_shape)
即可打印出網絡結構和參數統計
但是有一個問題:對於共享參數的模塊(說白了就是重複調用的模塊)會重複統計
這裏有一篇知乎專欄文章很直觀地分析了這個問題。
以後用torchsummary
統計參數時如若遇到參數量爆炸的情況可以用下面的方法進行統計,不過兩個還是結合着來用比較好,畢竟torchsummary
還可以打印網絡結構,比較直觀。
def count_parameters(model):
return sum(p.numel() for p in model.parameters()
if p.requires_grad)
看一下對比: