profile計算模型參數

from thop import profile class Test(nn.Module): def __init__(self, input_size, output_szie): super(Test, self).__init__() self.out = nn.Linear(input_size, output_szie) def forward(self, x): output = self.out(x) return output t = Test(10, 2) x = torch.randn(4, 10) profile(t, (x,), verbose=False) # (80.0, 22.0): 10*2 + 2 = 22.0 # total_flops += flops # model_params_num += params
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章