Pytorch 自定義torch.autograd.Function

轉載自:
https://zhuanlan.zhihu.com/p/27783097
https://www.jianshu.com/p/5d5d3957f684
ReLu 函數求導示例:

# -*- coding:utf8 -*-

import torch
from torch.autograd import Variable

class MyReLU(torch.autograd.Function):

    def forward(self, input_):
        # 在forward中,需要定義MyReLU這個運算的forward計算過程
        # 同時可以保存任何在後向傳播中需要使用的變量值
        self.save_for_backward(input_)         # 將輸入保存起來,在backward時使用
        output = input_.clamp(min=0)           # relu就是截斷負數,讓所有負數等於0
        return output

    def backward(self, grad_output):
        # 根據BP算法的推導(鏈式法則),dloss / dx = (dloss / doutput) * (doutput / dx)
        # dloss / doutput就是輸入的參數grad_output、
        # 因此只需求relu的導數,在乘以grad_outpu
        input_, = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_ < 0] = 0               # 上訴計算的結果就是左式。即ReLU在反向傳播中可以看做一個通道選擇函數,所有未達到閾值(激活值<0)的單元的梯度都爲0
        return grad_input

# Wrap一個ReLU函數
# 可以直接把剛纔自定義的ReLU類封裝成一個函數,方便直接調用
def relu(input_):
    # MyReLU()是創建一個MyReLU對象,
    # Function類利用了Python __call__操作,使得可以直接使用對象調用__call__制定的方法
    # __call__指定的方法是forward,因此下面這句MyReLU()(input_)相當於
    # return MyReLU().forward(input_)
    return MyReLU()(input_)

input_ = Variable(torch.linspace(-3, 3, steps=5))
print input_
print relu(input_)
# input_ = Variable(torch.randn(1))
# relu = MyReLU()
# output_ = relu(input_)
#
# # 這個relu對象,就是output_.creator,即這個relu對象將output與input連接起來,形成一個計算圖
# print relu
# print output_.creator

在pytorch庫中:

import math

import torch
import torch.nn as nn
from torch.autograd import Function


class ReLUF(Function):

    @staticmethod
    def forward(cxt, input):
        cxt.save_for_backward(input)

        output = input.clamp(min=0)

        return output

    @staticmethod
    def backward(cxt, grad_output):
        input, = cxt.saved_variables

        grad_input = grad_output.clone()
        grad_input[input < 0] = 0

        return grad_input


class LinearF(Function):

    @staticmethod
    def forward(cxt, input, weight, bias=None):
        cxt.save_for_backward(input, weight, bias)

        output = input.mm(weight.t())
        if bias is not None:
            output += bias

        return output

    @staticmethod
    def backward(cxt, grad_output):
        input, weight, bias = cxt.saved_variables

        grad_input = grad_weight = grad_bias = None
        if cxt.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if cxt.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and cxt.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        if bias is not None:
            return grad_input, grad_weight, grad_bias
        else:
            return grad_input, grad_weight


# aliases
relu = ReLUF.apply
linear = LinearF.apply


# simple test
if __name__ == "__main__":
    from torch.autograd import Variable

    torch.manual_seed(1111)
    a = torch.randn(2, 3)

    va = Variable(a, requires_grad=True)
    vb = relu(va)
    print va.data, vb.data

    vb.backward(torch.ones(va.size()))
    print va.grad.data

pytorch中文文檔的說明:
https://pytorch-cn.readthedocs.io/zh/latest/notes/extending/

其他博客講解:
https://blog.csdn.net/Hungryof/article/details/78346304
https://blog.csdn.net/u012436149/article/details/78829329
https://blog.csdn.net/tsq292978891/article/details/79364140

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章