入門神經網絡優化算法(五):一文看懂二階優化算法Natural Gradient Descent(Fisher Information)

歡迎查看我的博客文章合集:我的Blog文章索引::機器學習方法系列,深度學習方法系列,三十分鐘理解系列等

這個系列會有多篇神經網絡優化方法的複習/學習筆記,主要是一些優化器。目前有計劃的包括:

  • 入門神經網絡優化算法(一):Gradient Descent,Momentum,Nesterov accelerated gradient
  • 入門神經網絡優化算法(二):Adaptive Optimization Methods:Adagrad,RMSprop,Adam
  • 入門神經網絡優化算法(三):待定
  • 入門神經網絡優化算法(四):AMSGrad,Radam等一些Adam變種
  • 入門神經網絡優化算法(五):二階優化算法Natural Gradient Descent(Fisher Information)
  • 入門神經網絡優化算法(六):二階優化算法K-FAC
  • 入門神經網絡優化算法(七):二階優化算法Shampoo

二階優化算法Natural Gradient Descent,是從分佈空間推導最速梯度下降方向的方法,和牛頓方法有非常緊密的聯繫。Fisher Information Matrix往往可以用來代替牛頓法的Hessian矩陣計算。下面詳細道來。

1. Fisher Information Matrix

瞭解Natural Gradient Descent方法,需要先了解Fisher Information Matrix的定義。參考資料主要有[1][2],加上我自己的理解。

1.1 Score function

假設我們有一個模型參數向量是θ\theta,似然函數一般表示成p(xθ)p(x | \theta)。在很多算法中,我們經常需要學習參數θ\theta以最大化似然函數(likelihood)p(xθ)p(x | \theta)。這個時候,定義Score function s(θ)s(\theta),the gradient of log likelihood function:
s(θ)=θlogp(xθ) s(\theta) = \nabla_{\theta} \log p(x \vert \theta) \\

這個Score function在很多地方都要用到,特別的,在強化學習Policy Gradient類方法中,我們會直接用到Score function求參數梯度來更新policy參數。

Score function的性質:The expected value of score function wrt. the model is zero.

證明:
Ep(xθ)[s(θ)]=Ep(xθ)[logp(xθ)]=logp(xθ)p(xθ)dx=1p(xθ)p(xθ)p(xθ)dx=p(xθ)dx=p(xθ)dx=1=0 \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ s(\theta) \right] = \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \nabla \log p(x \vert \theta) \right] \\[5pt] = \int \nabla \log p(x \vert \theta) \, p(x \vert \theta) \, \text{d}x \\[5pt] = \int \frac{1}{p(x \vert \theta)} \nabla p(x \vert \theta) p(x \vert \theta) \text{d}x \\[5pt] = \int \nabla p(x \vert \theta) \, \text{d}x \\[5pt] = \nabla \int p(x \vert \theta) \, \text{d}x \\[5pt] = \nabla 1 \\[5pt] = 0

1.2 Fisher Information

雖然期望爲零,但是我們需要評估Score function的不確定性,我們採用協方差矩陣的期望(針對模型本身):
Ep(xθ)[(s(θ)0)(s(θ)0)T] \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ (s(\theta) - 0) \, (s(\theta) - 0)^{\text{T}} \right]
上述定義(協方差矩陣的期望,針對model p(xθ)p(x \vert \theta) )稱之爲Fisher Information,如果θ\theta是表示成一個列向量,那麼Score function也是一個列向量,而Fisher Information是一個矩陣形式,我們稱之爲Fisher Information Matrix

F=Ep(xθ)[logp(xθ)logp(xθ)T] \text{F} = \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \nabla \log p(x \vert \theta) \, \nabla \log p(x \vert \theta)^{\text{T}} \right]

但是呢,往往p(xθ)p(x \vert \theta) 形式是比較複雜的,甚至是一個模型的輸出,要計算期望是不太可能的。因此,實際上我們用的比較多的情況是,採用training data X={x1,x2,,xN}X = \{ x_1, x_2, \cdots, x_N \}計算得到的Empirical Fisher:
F=1Ni=1Nlogp(xiθ)logp(xiθ)T \text{F} = \frac{1}{N} \sum_{i=1}^{N} \nabla \log p(x_i \vert \theta) \, \nabla \log p(x_i \vert \theta)^{\text{T}}

1.3 Fisher矩陣和Hessian矩陣的關係

前面是背景介紹,下面進入正題,Fisher矩陣和Hessian矩陣的關係。可以證明:log似然函數的海森矩陣的期望的負數,等於Fisher Information Matrix.

Claim: The negative expected Hessian of log likelihood is equal to the Fisher Information Matrix F

證明:核心思想是,The Hessian of the log likelihood is given by the Jacobian of its gradient:
Hlogp(xθ)=J[p(xθ)p(xθ)]=Hp(xθ)p(xθ)p(xθ)p(xθ)Tp(xθ)p(xθ)=Hp(xθ)p(xθ)p(xθ)p(xθ)p(xθ)p(xθ)Tp(xθ)p(xθ)=Hp(xθ)p(xθ)(p(xθ)p(xθ))(p(xθ)p(xθ))T \text{H}_{\log p(x \vert \theta)} = \text{J} \left[\frac{\nabla p(x \vert \theta)}{p(x \vert \theta)}\right] \\[8pt] = \frac{ \text{H}_{p(x \vert \theta)} \, p(x \vert \theta) - \nabla p(x \vert \theta) \, \nabla p(x \vert \theta)^{\text{T}}}{p(x \vert \theta) \, p(x \vert \theta)} \\[8pt] = \frac{\text{H}_{p(x \vert \theta)} \, p(x \vert \theta)}{p(x \vert \theta) \, p(x \vert \theta)} - \frac{\nabla p(x \vert \theta) \, \nabla p(x \vert \theta)^{\text{T}}}{p(x \vert \theta) \, p(x \vert \theta)} \\[8pt] = \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} - \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right) \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)}\right)^{\text{T}}

推導的時候主要注意,p(xθ)p(x \vert \theta)是一個標量;而p(xθ)\nabla p(x \vert \theta)是對參數的梯度,是一個列向量。
然後Taking expectation wrt. the model, we have:

Ep(xθ)[Hlogp(xθ)]=Ep(xθ)[Hp(xθ)p(xθ)(p(xθ)p(xθ))(p(xθ)p(xθ))T]=Ep(xθ)[Hp(xθ)p(xθ)]Ep(xθ)[(p(xθ)p(xθ))(p(xθ)p(xθ))T]=Hp(xθ)p(xθ)p(xθ)dxEp(xθ)[logp(xθ)logp(xθ)T]=Hp(xθ)dxF=H1F=F. \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \text{H}_{\log p(x \vert \theta)} \right] = \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} - \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right) \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right)^{\text{T}} \right] \\[5pt] = \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} \right] - \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right) \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)}\right)^{\text{T}} \right] \\[5pt] = \int \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} p(x \vert \theta) \, \text{d}x \, - \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \nabla \log p(x \vert \theta) \, \nabla \log p(x \vert \theta)^{\text{T}} \right] \\[5pt] = \text{H}_{\int p(x \vert \theta) \, \text{d}x} \, - \text{F} \\[5pt] = \text{H}_{1} - \text{F} \\[5pt] = -\text{F} \, .

因此我們得到了:F=Ep(xθ)[Hlogp(xθ)]\text{F} = -\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \text{H}_{\log p(x \vert \theta)} \right],證明完畢。我們可以將F的作用看作是對數似然函數曲率的度量。一種很自然的想法就是,在二階優化算法中,比如牛頓法中,需要計算Hessian矩陣,那麼是否可以用Fisher矩陣來代替Hessian舉證呢?這就引出了下面要講的natural gradient方法了。

2. 自然梯度下降法Natural Gradient Descent

先來講一講parameter space和distribution space的概念,導致了對梯度下降的不同理解。

  • parameter space:一般我們解決優化問題最常用的方法是用梯度下降,每一步優化方向採用負梯度方向,θL(θ)-\nabla_\theta \mathcal{L}(\theta)。可以知道,負梯度方向是在當前的參數值θ\theta的local neighborhood裏loss在參數空間的最速下降方向。
    θL(θ)θL(θ)=limϵ01ϵarg mind s.t. dϵL(θ+d). \frac{-\nabla_\theta \mathcal{L}(\theta)}{\lVert \nabla_\theta \mathcal{L}(\theta) \rVert} = \lim_{\epsilon \to 0} \frac{1}{\epsilon} \mathop{\text{arg min}}_{d \text{ s.t. } \lVert d \rVert \leq \epsilon} \mathcal{L}(\theta + d) \, .
    上面的表達式是,參數空間中最陡的下降方向是選取一個向量dd,使得新參數θ+d\theta+d在當前參數θ\thetaϵ\epsilon-鄰域內,並且我們選取使損失最小的dd。注意我們用歐幾里德範數來表示這個鄰域。因此,梯度下降的優化依賴於參數空間的歐氏幾何度量。

  • distribution space:同時,如果我們的目標是最小化損失函數(最大化似然),那麼我們自然會在所有可能的似然空間中採取優化步驟,通過參數θ\theta來實現。由於似然函數本身是一個概率分佈,我們稱它所在的空間爲分佈空間(distribution space)。因此,在分佈空間中採用最陡下降方向,而不是參數空間,是有道理的。

在distribution space中,用什麼距離度量呢?常用的選擇就是用KL散度(KL-divergence),KL散度常用語評估兩個分佈的接近程度。但是,實際上KL散度是不對稱的,因此理論上不是一個distance metric,但是呢,很多地方還是用KL散度來衡量兩個分佈的接近程度。(as dd goes to zero, KL-divergence is asymptotically symmetric. So, within a local neighbourhood, KL-divergence is approximately symmetric [3].)

2.1 分佈空間中的最速下降,Natural gradient方法

前面講了那麼多,終於要引出自然梯度方法的基本推導了。

先推導KL散度的泰勒展開有如下形式:
KL[p(xθ)p(xθ+d)]12dTFd\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta + d)] \approx \frac{1}{2} d^\text{T} \text{F} d

證明:寫出二階泰勒展開:

KL[p(xθ)p(xθ+d)]KL[p(xθ)p(xθ)]θ=θ+(θKL[p(xθ)p(xθ)]θ=θ)Td+12dTθ2KL[p(xθ)p(xθ)]θ=θd=KL[p(xθ)p(xθ)]Ep(xθ)[θlogp(xθ)]Td+12dTFd=12dTFd \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta+d)] \approx \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]\vert_{\theta' = \theta} + (\left. \nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')] \right\vert_{\theta' = \theta})^\text{T} d + \frac{1}{2} d^\text{T} \nabla_{\theta'}^2 \, \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]\vert_{\theta' = \theta}d \\[5pt] =\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta)] - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_\theta \log p(x \vert \theta) ]^\text{T} d + \frac{1}{2} d^\text{T} \text{F} d = \frac{1}{2} d^\text{T} \text{F} d\\[5pt]

這樣理解爲什麼引入θ\theta':把KL散度第一個p(xθ)p(x \vert \theta)看成一個確定的分佈,而變化的是在第二個分佈的參數上。我們依次來看下約等號\approx後面這三項:

  • 泰勒展開的第一項 KL[pθpθ]=0\text{KL}[p_{\theta} \, \Vert \, p_{\theta}] = 0

  • 第二項的推導:
    θKL[p(xθ)p(xθ)]=θEp(xθ)[logp(xθ)]θEp(xθ)[logp(xθ)]=Ep(xθ)[θlogp(xθ)]=0 \nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')] = \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ] \\[8pt] = - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_{\theta'} \log p(x \vert \theta') ] =0\\[5pt]
    考慮θ=θ\vert_{\theta' = \theta}的話,第二項包含了Score function的期望。正好是本章節前面Fisher Matrix部分講過的,Score function的期望,已經證明過是0。

  • 第三項,需要用到前面第一章證明過的,F=Ep(xθ)[Hlogp(xθ)]\text{F} = -\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \text{H}_{\log p(x \vert \theta)} \right],以及如下性質:Fisher Information Matrix F is the Hessian of KL-divergence between two distributions p(xθ)p(x \vert \theta) and p(xθ)p(x \vert \theta'), with respect to θ\theta', evaluated at θ=θ\theta' = \theta,下面是推導過程:
    KL[p(xθ)p(xθ)]=Ep(xθ)[logp(xθ)]Ep(xθ)[logp(xθ)] \text{KL} [p(x \vert \theta) \, \Vert \, p(x \vert \theta')] = \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ]
    The first derivative wrt. θ\theta' is:
    θKL[p(xθ)p(xθ)]=θEp(xθ)[logp(xθ)]θEp(xθ)[logp(xθ)]=Ep(xθ)[θlogp(xθ)]=p(xθ)θlogp(xθ)dx \nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')] = \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ] \\[5pt] = - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_{\theta'} \log p(x \vert \theta') ] \\[5pt] = - \int p(x \vert \theta) \nabla_{\theta'} \log p(x \vert \theta') \, \text{d}x
    The second derivative is:
    θ2KL[p(xθ)p(xθ)]θ=θ=p(xθ)θ2logp(xθ)θ=θdx=p(xθ)Hlogp(xθ)dx=Ep(xθ)[Hlogp(xθ)]=F \nabla_{\theta'}^2 \, \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]\vert_{\theta' = \theta} = - \int p(x \vert \theta) \, \nabla_{\theta'}^2 \log p(x \vert \theta')\vert_{\theta' = \theta} \, \text{d}x \\[5pt] = - \int p(x \vert \theta) \, \text{H}_{\log p(x \vert \theta)} \, \text{d}x \\[5pt] = - \mathop{\mathbb{E}}_{p(x \vert \theta)} [\text{H}_{\log p(x \vert \theta)}] \\[5pt] = \text{F}

所以得到KL散度的二階泰勒展開形式:
KL[p(xθ)p(xθ+d)]12dTFd\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta + d)] \approx \frac{1}{2} d^\text{T} \text{F} d

現在,我們想知道什麼是使分佈空間中的損失函數L最小化的更新向量d,以便我們知道哪個方向的KL散度減小得最多。這類似於最速下降法,但在以KL散度爲度量的分佈空間,而不是通常的以歐氏度量的參數空間。爲此,我們將最小化:

d=arg mind s.t. KL[pθpθ+d]cL(θ+d), d^* = \mathop{\text{arg min}}_{d \text{ s.t. } \text{KL}[p_\theta \Vert p_{\theta + d}] \leq c} \mathcal{L} (\theta + d) \, ,

如果我們寫出上面的最小化問題在拉格朗日乘子法形式,用二階泰勒展開近似KL散度,用一階泰勒級數展開近似L\mathcal{L}

d=arg mindL(θ+d)+λ(KL[pθpθ+d]c)arg mindL(θ)+θL(θ)Td+12λdTFdλc d^* = \mathop{\text{arg min}}_d \, \mathcal{L} (\theta + d) + \lambda \, (\text{KL}[p_\theta \Vert p_{\theta + d}] - c) \\[8pt] \approx \mathop{\text{arg min}}_d \, \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\text{T} d + \frac{1}{2} \lambda \, d^\text{T} \text{F} d - \lambda c
其中λ\lambda是拉格朗日系數,要求解這個優化問題,我們求dd的梯度等於0:
0=d[L(θ)+θL(θ)Td+12λdTFdλc]=θL(θ)+λFdλFd=θL(θ)d=1λF1θL(θ) 0 = \frac{\partial}{\partial d} \left[\mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\text{T} d + \frac{1}{2} \lambda \, d^\text{T} \text{F} d - \lambda c\right] \\[8pt] = \nabla_\theta \mathcal{L}(\theta) + \lambda \, \text{F} d \\[8pt] \lambda \, \text{F} d = -\nabla_\theta \mathcal{L}(\theta) \\[8pt] d = -\frac{1}{\lambda} \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta) \\[8pt]

因此,先不看1λ\frac{1}{\lambda}(可以一起考慮吸收到learning rate部分),我們得到在分佈空間中,最優的更新方向是F1θL(θ)-\text{F}^{-1} \nabla_\theta \mathcal{L}(\theta)。(類比二階優化方法的牛頓法,更新方向是H1θL(θ)-\text{H}^{-1} \nabla_\theta \mathcal{L}(\theta),非常類似吧)。

我們把Natural gradient 定義成:~θL(θ)=F1θL(θ)\tilde{\nabla}_\theta \mathcal{L}(\theta) = \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta). 自然梯度下降算法的基本流程如下:(一般我們會採用batch模式的Empirical Fisher Matrix:F=1Ni=1Nlogp(xiθ)logp(xiθ)T\text{F} = \frac{1}{N} \sum_{i=1}^{N} \nabla \log p(x_i \vert \theta) \, \nabla \log p(x_i \vert \theta)^{\text{T}}
在這裏插入圖片描述

與Adam關係的類比討論

在數據量較少的非常簡單的模型中,我們看到可以很容易地實現自然梯度下降。但衆所周知,深度學習模型中的參數數目非常大,千萬甚至億級參數量模型很常見,即使一層都有上百萬參數。這類模型的Fisher信息矩陣難以計算、存儲、以及求逆。這和二階優化方法在深度學習中不受歡迎的原因是一樣的。

解決這個問題的一種方法是計算近似的Fisher/Hessian。像ADAM[5]這樣的方法計算梯度的一階和二階moving average(m和v)。m是動量momentum,這裏不討論。而v可以看成是Fisher信息矩陣的近似——但將其約束爲對角矩陣(協方差的對角線元素是梯度的平方)。因此,在ADAM中,我們只需要O(n)O(n)空間來存儲(F的近似值)而不是O(n2)O(n^2),並且可以在O(n)O(n)而不是O(n3)O(n^3)中進行求逆運算。在實踐中,ADAM工作得非常好,是目前優化深層神經網絡的基準優化方法。

在這裏插入圖片描述

OK,這一篇終於基本寫好了,後面會繼續這個話題,再記錄一下如何加速自然梯度方法的工作,主要是比較知名的K-FAC算法。這篇可能還有一些關於自然梯度的引申討論,過幾天再補。參考[6][7]。TBD…

參考資料

[1] https://wiseodd.github.io/techblog/2018/03/11/fisher-information/
[2] https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/
[3] Martens, James. “New insights and perspectives on the natural gradient method.” arXiv preprint arXiv:1412.1193 (2014).
[4] Ly, Alexander, et al. “A tutorial on Fisher information.” Journal of Mathematical Psychology 80 (2017): 40-55
[5] ADAM A METHOD FOR STOCHASTIC OPTIMIZATION. 2015
[6] 多角度理解自然梯度,https://zhuanlan.zhihu.com/p/82934100
[7] 如何理解 natural gradient descent?,https://www.zhihu.com/question/266846405

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章