pytorch的四個hook函數

  訓練神經網絡模型有時需要觀察模型內部模塊的輸入輸出,或是期望在不修改原始模塊結構的情況下調整中間模塊的輸出,pytorch可以用hook回調函數來實現這一功能。主要使用四個hook註冊函數:register_forward_hook、register_forward_pre_hook、register_full_backward_hook、register_full_backward_pre_hook。這四個函數可以被繼承nn.Module的任意模塊調用,傳入hook函數並進行註冊,從而在執行該模塊的相應階段調用hook函數實現所需功能。

register_forward_hook(self, hook, *, prepend, with_kwargs)

  爲模塊註冊一個在該模塊前向傳播之後執行的回調函數。

  hook(module, args, output):需執行的回調函數對象,module爲當前模塊引用,args爲當前模塊前向傳播輸入,output爲當前模塊前向傳播輸出。可以返回修改後的output來修改該模塊前向傳播輸出。

  prepend:將該hook函數放在回調函數列表最前面,從而最先執行,否則放在隊列最後。

  with_kwargs:hook函數是否傳入關鍵字參數,如果爲True,則hook可以額外增加關鍵則參數。

  register_forward_hook註冊函數本身返回一個handle句柄,可執行handle.remove()將註冊的該hook函數移除。

register_forward_pre_hook(self, hook, *, prepend, with_kwargs)

  爲模塊註冊一個在該模塊前向傳播之前執行的回調函數。

  hook(module, args):args爲該模塊前向傳播輸入。可以返回修改後的args來修改該模塊前向傳播輸入。

  其它參數、特性與前面一致。

register_full_backward_hook(self, hook, prepend)

  爲模塊註冊一個在該模塊反向傳播之後執行的回調函數。

  hook(module, grad_input, grad_output):grad_input與grad_output分別爲該模塊前向傳播輸入和輸出的梯度。可以返回修改後的grad_input來修改該模塊前向傳播輸入的梯度。

register_full_backward_pre_hook(self, hook, prepend)

  爲模塊註冊一個在該模塊反向傳播之前執行的回調函數。

  hook(module, grad_output):grad_output爲該模塊前向傳播輸出的梯度。可以返回修改後的grad_output來修改這一梯度。

 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章