pytorch模型剪枝學習筆記

pytorch代碼倉庫

pytorch在19年11月份的時候合入了這部分剪枝的代碼。pytorch提供一些直接可用的api,用戶只需要傳入需要剪枝的module實例和需要剪枝的參數名字,系統自動幫助完成剪枝操作,看起來接口挺簡單。比如 def random_structured(module, name, amount, dim)

pytorch支持的幾種類型的剪枝策略:

詳細分析

  • pytorch提供了一個剪枝的抽象基類‘‘class BasePruningMethod(ABC)’,所有剪枝策略都需要繼承該基類,並重載部分函數就可以了

  • 一般情況下需要重載init和compute_mask方法,call, apply_mask, apply, prune和remove不需要重載,例如官方提供的RandomUnstructured剪枝方法 file

  • 基類實現的6個方法: file

  • 剪枝的API接口,可以看到支持用戶自定義的剪枝mask,接口爲custom_from_mask file

  • API的實現,使用classmethod的方法,剪枝策略的實例化在框架內部完成,不需要用戶實例化

  • 剪枝的大隻過程:

    1. 根據用戶選擇的剪枝API生成對應的策略實例,此時會判斷需要做剪枝操作的module上是否已經掛有前向回調函數,沒有則生成新的,有了就在老的上面添加,並且生成PruningContainer。從這裏可以看出,對於同一個module使用多個剪枝策略時,pytorch通過PruningContainer來對剪枝策略進行管理。PruningContainer本身也是繼承自BasePruningMethod。同時設置前向計算的回調,便於後續訓練時調用。
    2. 接着根據用戶輸入的module和name,找到對應的參數tensor。如果是第一次剪枝,那麼需要生成_orig結尾的tensor,然後刪除原始的module上的tensor。如name爲bias,那麼生成bias_orig存起來,然後刪除module.bias屬性。
    3. 獲取defaultmask,然後調用method.computemask生成當前策略的mask值。生成的mask會被存在特定的緩存module.register_buffer(name + "_mask", mask)。這裏的compute_mask可能是兩種情況:如果只有一個策略,那麼調用的時候對應剪枝策略的compute_mask方法,如果一個module有多個剪枝策略組合,那麼調用的應該是PruningContainer的compute_mask file
    4. 執行剪枝,保存剪枝結果到module的屬性,註冊訓練時的剪枝回調函數,剪枝完成。新的mask應用在orig的tensor上面生成新的tensor保存的對應的name屬性 file
  • remove接口 pytorch還提供各類一個remove接口,目的是把之前的剪枝結果持久化,具體操作就是刪除之前生成的跟剪枝相關的緩存或者是回調hook接口,設置被剪枝的name參數(如bias)爲最後一次訓練的值。 file

  • 自己寫一個剪枝策略接口也是可以的: file

    1. 先寫一個剪枝策略類繼承BasePruningMethod
    2. 然後重載基類的compute_mask方法,寫自己的計算mask方法

官方完整教程在這裏

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章