torch.autograd.Function

pytorch自定義layer有兩種方式:

方式一:通過繼承torch.nn.Module類來實現拓展。只需重新實現__init__和forward函數。

方式二通過繼承torch.autograd.Function,除了要實現__init__和forward函數,還要實現backward函數(就是自定義求導規則)。

方式一看着更簡單一點,torch.nn.Module不香麼?爲毛要用方式二。因爲當我們自定義的函數torch.nn.functioanl裏沒有的時候,或者一些操作不可導,就需要自己定義求導方式,也就是所謂的Extending torch.autograd

官方的示例是這個樣子的: Class  torch.autograd.Function

>>> class Exp(Function):
>>>
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result

詳細學習參見這個博客講的很好https://blog.csdn.net/qq_27825451/article/details/95189376

發佈了25 篇原創文章 · 獲贊 8 · 訪問量 3萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章