用torchsummary打印pytorch模型參數信息,並計算模型的FLOPs

打印模型參數信息

在python3環境下安裝torchsummary

from torchsummary import summary
import torchvision.models as models
model = models.resnet152()
model = model.cuda()
summary(model, input_size=(3,224,224), batch_size=-1, device='cuda')

計算模型FLOPs

代碼詳見:https://github.com/TangShengqin/pytorch_learn/blob/master/model_flops.py
其中,multiply_adds = True會同時累計加法和乘法的計算量。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章