Neural Ordinary Differential Equation 神經常微分方程(Neural ODEs)

  用微分方程的視角來看待和理解神經網絡是一種新的視角,該觀點最早出現在2016年鄂維南院士的一篇proposal裏:A Proposal on Machine Learning via Dynamical Systems.

Motivation

The core idea is that certain types of neural networks are analogous to a discretized differential equation, so maybe using off-the-shelf differential equation solvers will help get better results.

主要思想是:特定類型的神經網絡可以看作離散的微分方程,所以使用現成的微分方程求解器可以幫助獲得更好的結果。

First to see the contribution described in the original paper: “We introduce a new family of deep neural network models. Instead of specifying a discrete sequence of hidden layers, we parameterize the derivative of the hidden state using a neural network.”

先來看看原文中怎樣描述這個貢獻: “我們提出了一族新的神經網絡模型…”。

不是指定一個離散序列,我們參數化了網絡隱藏狀態的導數。

Why we should to parameterize the derivative of the hidden state of the neural network? The answer is we should capture the characteristic of the middle layer of the neural network. Here, the derivative of the hidden layer is equal to the gradient in the backpropagate progress.

爲什麼參數化網絡隱藏狀態的導數,也就是中間層的導數,因爲要建立隱藏狀態的微分方程。中間層的導數不就是網絡的梯度嗎?

如果直接將中間層的結果求解出來,是否時避免了反向傳播過程?


Reverse-mode automatic differentiation of ODE solutions

反向模式的自動微分ODE的解決方案

Let’s we show the result of the forward progress of neural network.
我們先來看NN(Neural Network)的前向過程:在這裏插入圖片描述
z(t1)z(t_1) 代表 t1t_1 時刻的隱藏狀態(hidden state),而當隱藏狀態被連續化後,t0t_0t1t_1 時刻的中間隱藏狀態的和就是等式中間部分的積分項。而整個前向過程可以用 ODE 求解器進行求解。注意,這裏並沒有定義 ff 的具體形式,一個需要考慮的問題是:ODE solver 是否可以求解任意形式的 ff。//todo

“The main technical difficulty in training continuous-depth networks is performing reverse-mode differentiation (also known as backpropagation) through the ODE solver.”

難點是使用 ODE solver 對連續的網絡求解其反向模型的微分形式

We treat the ODE solver as a black box, and compute gradients using the adjoint sensitivity method. This approach computes gradients by solving a second, augmented ODE backwards in time, and is applicable to all ODE solvers. "

這裏,將 ODE solver 看作是一個黑盒子,使用伴隨敏感方法來求解梯度。該方法通過求解第二個、增強了的時間向後(時間軸反向)的 ODE 來計算梯度,而且所有 ODE solvers 都適用。具體過程爲:

To optimize LL, we require gradients with respect to θ\theta. The first step is to determining how the gradient of the loss depends on the hidden state z(t)z(t) at each instant. This quantity is called the adjoint a(t)=Lz(t)a(t) =\frac{\partial L}{\partial z(t)}. Its dynamics are given by another ODE, which can be thought of as the instantaneous analog of the chain rule:

爲了優化損失 L, 需要計算它對 θ\theta 的導數。第一步是怎樣確定梯度依賴的隱層狀態 z(t)z(t). 該性質稱爲 伴隨。它的動態過程被另一個 ODE 來求解,可以把這種瞬時性被看作鏈式法則:
在這裏插入圖片描述(1)
該等式在1962年由 Pontryagin et al. 的論文《The mathematical theory of optimal processes》給出過證明,不過,本文作者也給出了相應的更簡潔的證明過程:
  對於連續的隱層狀態,可以將在時間上變化後的 ε\varepsilon 記作:
在這裏插入圖片描述(2)
上述公式說明,下一個狀態zz 是關於上一個狀態的函數(這裏將參數 θ\theta 看作常量,具體的積分值由 ff 決定)。 因此,相應的鏈式法則可以記作:
在這裏插入圖片描述(3)
由此,可以證明(1)式:
在這裏插入圖片描述
通過上述證明過程(引入 Tε(z(t))T_{\varepsilon}(z(t)) ,以說明 z(t+ε)z(t+\varepsilon)z(t)z(t)的函數),第二步用到等式(3),另外對等式(2)進行泰勒展開(TεT_{\varepsilon} 中的tt 被隱含了),注意展開過程中的無窮小參數同樣取 ε\varepsilon,然後就可以得到等式(1)。

  We specify the constraint on the last time point, which is simply the gradient of the loss wrt the last time point, and can obtain the gradients with respect to the hidden state at any time, including the initial value.

  這裏就可以看出 ODE 沿時間的反向過程和 NN 中反向傳播(BP)的相似性了。也就是通過 ODE 系統,前向和後向都是可以計算的。這裏假設(限制)最後時刻(TNT_N)的隱層狀態是已知的(可以直接通過 loss 的梯度獲取),就可以求解任意時刻的隱層狀態了(包括初始時刻):
在這裏插入圖片描述
  由此,整個 ODE 的反向過程的理論部分證明完成。

  這裏引入了一個伴隨狀態(Adjoint State),它和前向狀態相反,通過另一個 ODE 來求解。 關鍵是它們是怎樣建立聯繫的?見下圖:

在這裏插入圖片描述
  The adjoint sensitivity method solves an augmented ODE backwards in time. The augmented system contains both the original state and the sensitivity of the loss with respect to the state.
  伴隨敏感度方法使用一個增強的在時間上反向的 ODE。該增強系統同時包括 原來的狀態 a(t)a(t) 和損失對該狀態的敏感度 La(t)z(tN)\frac{\partial La(t)}{\partial {z(t_N)}}。具體它倆是怎麼計算的?
  答案是:由損失敏感度 La(t)z(tN)\frac{\partial La(t)}{\partial {z(t_N)}} 調節伴隨(adjoint)狀態 a(t)a(t), 然後再有伴隨狀態 a(t)a(t) 得到損失敏感度 La(t)z(tN)\frac{\partial La(t)}{\partial {z(t_N)}} 。這是 ODE 反向的鏈式過程。至此,整個反向傳播的過程就被模擬了!

  Computing the gradients with respect to the parameters θ requires evaluating a third integral, which depends on both z(t) and a(t):
  計算關於 θ\theta 的梯度,還要計算相關變量 z(t) and a(t) 的積分:
在這裏插入圖片描述(4)

  通過等式(1)和(4)就可以計算出梯度了,a(t)Tfz{a(t)}^T \frac{\partial f}{\partial z}a(t)Tfθ{a(t)}^T \frac{\partial f}{\partial \theta} 的vector-Jacobian products 都可以通過 ODE solver 快速求解。 所有的積分解:z,a,Lθz, a, \frac{\partial L}{\partial \theta} 都可以通過一個 ODE solver 來求解,可以將它們組合成一個向量解 (增強的狀態,augmented state)。具體步驟見算法 1:
在這裏插入圖片描述
該算法基本上是上述過程的綜合。首先定義初始狀態 s0s_0,然後定義 增強狀態,aug_dynamics,該狀態包括f(z(t),t,θ)f(z(t),t,\theta)a(t)Tfz{a(t)}^T \frac{\partial f}{\partial z}a(t)Tfθ{a(t)}^T \frac{\partial f}{\partial \theta} 的vector-Jacobian products(通過自動微分工具得到)。然後通過 ODE solver 求解前一時刻的隱層狀態,敏感狀態,和梯度。注意,這些都是合併起來的向量形式(算子形式的張量?)。最後,返回敏感狀態(用以下一時刻計算敏感狀態)和梯度(用以更新參數 θ\theta)。


Replacing residual networks with ODEs

將ResNets 換成 ODEs

Software: To solve ODE initial value problems numerically, we use the implicit Adams method implemented in LSODE and VODE and interfaced through the scipy.integrate package. Being an implicit method, it has better guarantees than explicit methods such as Runge-Kutta but requires solving a nonlinear optimization problem at every step.This setup makes direct backpropagation through the integrator difficult.

軟件實現: 爲了求解 ODE 的數值解, 作者使用 Adams (一種梯度優化方法)方法實現了 LSODE 和 VODE 的scipy.integrate 接口。 作爲一種隱式方法,它比顯式方法有較好的保證,如 Runge-Kutta 需要在每一步求解非線性優化問題。這種設置使得直接使用積分器求解反向傳播是困難的。作者使用 Python 的自動微分方法實現了伴隨敏感方法,並使用 Tensorflow 在GPU上實現了 隱層狀態的動態和求導(從Fortran ODE Solver 調用,從 Python autograd 中調用)。

Model Architectures: We experiment with a small residual network which downsamples the input twice then applies 6 standard residual blocks He et al. (2016b), which are replaced by an ODESolve module in the ODE-Net variant. We also test a network with the same architecture but where gradients are backpropagated directly through a Runge-Kutta integrator, referred to as RK-Net.

  論文中用兩個降採樣和6個殘差塊的小型 ResNet 進行了實驗,將殘差塊替換爲ODESolve 模塊就變成了 ODE-Net 變體。作者還使用相同的架構測試了使用 Runge-Kutta 積分器來反向傳播梯度的 RK-Net。

  代碼實現

首先看整體網絡結構:

feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]

其中,ODEfunc 定義爲:

class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0 # number of forward ?

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

  與 Residual Block 不同的是多加了一次 Batch Normalization,ODEfunc 中的卷積 ConcatConv2d 實現爲:

class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t   # extract the first channel and multiply the time t.
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)

  可以看到 ConcatConv2d 和原來的 卷積方式 基本相同,只是在 前向過程中,添加了變量(variable)tt , 其中,torch.ones_like 返回一個填充了標量值1的張量,其大小與之相同 input ,乘以 tt 表示在 tt 時刻。然後,將 ttxx 合併(concatenation)起來,然後作爲卷積的輸入。這裏有個問題,爲什麼變量 tt 的 size 是 feature size,難道是對每個feature position 做連續化?//TODO (這裏 grad 的形狀和feature size 的形狀相同)。

  接下來就是 ODEBlock的定義:

class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol) # ODE forward
        return out[1]

  ODEBlock 中定義了 積分時間(integration time)t[0,1]t \in [0,1] ,然後在前向過程中傳入 odeint 中,關鍵點是 odeint, 按上述算法中。這裏 rtol 和 atol 是 容忍度(tolerance),即模型的精度設定。out[1] 是梯度(gradient)Lθ\frac{\partial L}{\partial \theta}。這樣我們求得了梯度。其中,odeint 的實現爲:

def odeint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
      tensor_input, func, y0, t = _check_inputs(func, y0, t)

    if options is None:
        options = {}
    elif method is None:
        raise ValueError('cannot supply `options` without specifying `method`')

    if method is None:
        method = 'dopri5'

    solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options)
    solution = solver.integrate(t)

    if tensor_input:
        solution = solution[0]
    return solution

The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition. Solves the initial value problem (IVP) for a non-stiff system of first order ODEs: yt=f(t,y)\frac{\partial y}{\partial t}=f(t,y) s.t. y(t0)=y0y(t_0)=y_0 where y is a Tensor of any shape.

odeint 解的是非複雜(non-stiff)系統的一階 ODE 的初值問題 (IVP),其中,y是任意形狀的張量。以下是其中參數的解釋:

"""
	Args:
        func: Function that maps a Tensor holding the state `y` and a scalar Tensor
            `t` into a Tensor of state derivatives with respect to time.
    func:把一個含有狀態張量 y 和常張量 t 映射到 一個關於時間可導的張量上。
        y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
            have any floating point or complex dtype.
    y0: NxD維度的張量,是 y 在 t[0] 的初始點,可以是任意複雜的類型。
        t: 1-D Tensor holding a sequence of time points for which to solve for
            `y`. The initial time point should be the first element of this sequence,
            and each time must be larger than the previous time. May have any floating
            point dtype. Converted to a Tensor with float64 dtype.
    t: 1xD的張量,表示一系列用於求解 y 的時間點。
        rtol: optional float64 Tensor specifying an upper bound on relative error,
            per element of `y`.
    rtol: 相對錯誤容忍度,以限制張量 y 中每個元素的上限值。(可調節)
        atol: optional float64 Tensor specifying an upper bound on absolute error,
            per element of `y`.
    atol: 絕對錯誤容忍度,以限制張量 y 中每個元素的上限值。(可調節)
        method: optional string indicating the integration method to use.
        method: 可選的string型 以決定那種 積分方法 被使用。
        options: optional dict of configuring options for the indicated integration
            method. Can only be provided if a `method` is explicitly set.
    options: 可選的字典類型,用於配置積分方法。
        name: Optional name for this operation.
	name:  爲該操作指定名稱。
    Returns:
        y: Tensor, where the first dimension corresponds to different
            time points. Contains the solved value of y for each desired time point in
            `t`, with the initial value `y0` being the first element along the first
            dimension.
    Returns: 返回第一個維度對應不同的時間點的 y 張量。
             包含 y 在每個時間點 t 上被期望的解。(所有時間點的解都被求得了),
             初始值 y0 是第一維度的第一個元素。
"""

  看一下 SOLOVE中的積分方法:

SOLVERS = {
    'explicit_adams': AdamsBashforth,
    'fixed_adams': AdamsBashforthMoulton,
    'adams': VariableCoefficientAdamsBashforth,
    'tsit5': Tsit5Solver,
    'dopri5': Dopri5Solver,
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
}

  這裏牽涉到微分方程的數值解法。這裏 AdamsBashforth、AdamsBashforthMoulton、Euler、Midpoint、RK4 (Fourth-order Runge-Kutta with 3/8 rule) 屬於 FixedGridODESolver (固定網格 ODE 求解器),其中,前兩個 Adams 類型的求解器 是作者自己實現的 Adam梯度下降方法來求解的 FixedGridODESolver。而VariableCoefficientAdamsBashforth、Tsit5Solver ()、Dopri5Solver (Runge-Kutta 4(5))屬於 AdaptiveStepsizeODESolver(自定義步長的 ODE 求解器)。論文中把 ODE solver 當作一個黑盒子(black box),我們知道它可以求解我們所需要的微分方程。這裏只看最簡單的 Euler 求解器:

class Euler(FixedGridODESolver):

    def step_func(self, func, t, dt, y):
        return tuple(dt * f_ for f_ in func(t, y))

  它只是實現了父類 FixedGridODESolver 中的 step_func,父類 FixedGridODESolver 的實現爲:

class FixedGridODESolver(object):
	def __init__(self, func, y0, step_size=None, grid_constructor=None, **unused_kwargs):
		...
		... # here, I omit some initialize progress in origin code
		# and omit some grid constructor progress.
		 
	@abc.abstractmethod
    def step_func(self, func, t, dt, y):
        pass

    def integrate(self, t):
        _assert_increasing(t) # t is increase sequence
        t = t.type_as(self.y0[0])
        time_grid = self.grid_constructor(self.func, self.y0, t) # grad
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
        time_grid = time_grid.to(self.y0[0])

        solution = [self.y0] # target solution list

        j = 1
        y0 = self.y0
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0) # use step function
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy)) # y1=y0+dy
            y0 = y1 # why to this?
			# linear interpolate the time sequence.
            while j < len(t) and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1

        return tuple(map(torch.stack, tuple(zip(*solution))))
        
    def _linear_interp(self, t0, t1, y0, y1, t):
        if t == t0:
            return y0
        if t == t1:
            return y1
        t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
        slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
        return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))

  這裏的積分 應該是對 差分 的積分,即根據初始值 y0y_0 和時間序列 tt 來求 yty_t。 首先構建 time grad,然後使用step_func,根據 func (NN 中的ff) 和 time grad 中的 t 以及 y0y_0 來計算 dydy, 接着,根據 y1=y0+dyy_1=y_0+dy 求得 y1y_1, 這裏有一行 y0=y1y_0=y_1, 爲什麼把y1賦值給 y0y_0 ? 然後再根據 y0y_0, y1y_1 求插值 ?這樣元素不就等於零了? //todo

  到這裏,整個 ODE-Net的方法和實現都走一遍了,但我們好像只看到了前向過程?沒有反向過程?這是因爲 反向過程被 Pytorch 在內部自動實現了 (autograd backpropagate),並沒有使用作者提出的 adjoint sensitivity method。作者指出使用 adjoint 方法可將 內存複雜度 降爲 O(1)O(1)

  Backpropagation through odeint goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method, which will allow solving with as many steps as necessary due to O(1) memory usage.

  odeint_adjoint simply wraps around odeint, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call. The biggest gotcha is that func must be a nn.Module when using the adjoint method. This is used to collect parameters of the differential equation.

odeint_adjoint 簡單第封裝了 odeint,並實現了反向過程。但其最大的缺憾(硬傷)是func ff 的取值必須是 nn.Module 的方法,這是爲了收集微分方程的參數。( Why must be collect parameters of the differential equation? The answer is use to backward of adjoint odeint.)看一下adjoint odeint 的實現過程:

def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):

    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = False
    if torch.is_tensor(y0):

        class TupleFunc(nn.Module):

            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        tensor_input = True
        y0 = (y0,)
        func = TupleFunc(func)

    flat_params = _flatten(func.parameters())
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)

    if tensor_input:
        ys = ys[0]
    return ys

  首先說明了odeint_adjoint 的變量是有序的,然後通過內部類封裝了一下 func,這裏明確的限制了 func 是 nn.Module,這樣 ODE-Net 的前向過程就實現了。接下來,通過 OdeintAdjointMethod 具體執行 ODE 的前向和反向過程:

class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args):
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, flat_params, rtol, atol, method, options = \
            args[:-7], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]

        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options

        with torch.no_grad():
            ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

  前向過程很簡單,通過繼承 torch.autograd.Function,將一些參數賦值給 ctx(沒有通過 self 實現,因爲ctx只在forward過程中存在。通過 self 會不會更直觀),並保存了 tt,func 的參數 和 odeint 的前向結果,以便在反向過程中使用。再看其反向過程:

    @staticmethod
    def backward(ctx, *grad_output):

        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with torch.set_grad_enabled(True):
                t = t.to(y[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
                )
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        with torch.no_grad():
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                dLd_cur_t = sum(
                    torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                )
                adj_time = adj_time - dLd_cur_t
                time_vjps.append(dLd_cur_t)

                # Run the augmented system backwards in time.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                aug_ans = odeint(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
                )

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0: adj_time = adj_time[1]
                if len(adj_params) > 0: adj_params = adj_params[1]

                adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

                del aug_y0, aug_ans

            time_vjps.append(adj_time)
            time_vjps = torch.cat(time_vjps[::-1])

            return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None)

  其中,torch.autograd.grad(outputs, inputs, grad_outputs=None, … ) 是用來計算輸出對輸入的梯度(Computes and returns the sum of gradients of outputs w.r.t. the inputs.)。這裏需要用到 自動微分 中的知識。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章