用torchsummary打印pytorch模型参数信息,并计算模型的FLOPs

打印模型参数信息

在python3环境下安装torchsummary

from torchsummary import summary
import torchvision.models as models
model = models.resnet152()
model = model.cuda()
summary(model, input_size=(3,224,224), batch_size=-1, device='cuda')

计算模型FLOPs

代码详见:https://github.com/TangShengqin/pytorch_learn/blob/master/model_flops.py
其中,multiply_adds = True会同时累计加法和乘法的计算量。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章