This series collects review/study notes on neural-network optimization methods, mainly optimizers.
Natural Gradient Descent is a second-order optimization algorithm that derives the steepest-descent direction in the space of distributions, and it is closely related to Newton's method: the Fisher Information Matrix can often stand in for the Hessian that Newton's method computes. The details follow.
1. Fisher Information Matrix
To understand Natural Gradient Descent, we first need the definition of the Fisher Information Matrix. The main references are [1][2], plus my own understanding.
1.1 Score function
Suppose our model has a parameter vector $\theta$, and the likelihood is written $p(x \vert \theta)$. In many algorithms we learn $\theta$ by maximising the likelihood $p(x \vert \theta)$. In this setting, the score function $s(\theta)$ is defined as the gradient of the log-likelihood:
$$
s(\theta) = \nabla_{\theta} \log p(x \vert \theta)
$$
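For concreteness, here is a small sketch of the score function (my own toy example, not from the references): a univariate Gaussian $p(x \vert \theta) = \mathcal{N}(x; \mu, \sigma^2)$ with $\theta = (\mu, \log\sigma)$, for which the score has a closed form.

```python
import numpy as np

def score(theta, x):
    """Score function: gradient of log N(x; mu, sigma^2) w.r.t. theta = (mu, log_sigma)."""
    mu, log_sigma = theta
    sigma = np.exp(log_sigma)
    d_mu = (x - mu) / sigma**2                   # d/dmu  log p(x | theta)
    d_log_sigma = ((x - mu) / sigma) ** 2 - 1.0  # d/dlog_sigma  log p(x | theta)
    return np.array([d_mu, d_log_sigma])

print(score((0.0, 0.0), 1.0))   # [1.0, 0.0]
```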
The score function shows up in many places; in particular, Policy Gradient methods in reinforcement learning use the score function directly to compute parameter gradients for updating the policy.
A key property of the score function: its expected value with respect to the model is zero.
Proof:
$$
\begin{aligned}
\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ s(\theta) \right]
&= \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \nabla \log p(x \vert \theta) \right] \\
&= \int \nabla \log p(x \vert \theta) \, p(x \vert \theta) \, \text{d}x \\
&= \int \frac{1}{p(x \vert \theta)} \nabla p(x \vert \theta) \, p(x \vert \theta) \, \text{d}x \\
&= \int \nabla p(x \vert \theta) \, \text{d}x \\
&= \nabla \int p(x \vert \theta) \, \text{d}x \\
&= \nabla 1 = 0
\end{aligned}
$$
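A quick Monte Carlo sanity check of this property on the toy Gaussian (my own example): when the samples are drawn from the model itself, the sample mean of the score should be close to zero.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 2.0
x = rng.normal(mu, sigma, size=200_000)   # samples from p(x | theta) itself
score_mu = (x - mu) / sigma**2            # score w.r.t. mu at each sample
print(np.mean(score_mu))                  # close to 0
```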
1.2 Fisher Information
Although its expectation is zero, we still want to quantify the uncertainty of the score function, so we take its covariance with respect to the model itself:
$$
\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ (s(\theta) - 0) \, (s(\theta) - 0)^{\text{T}} \right]
$$
This quantity (the covariance of the score, taken with respect to the model $p(x \vert \theta)$) is called the Fisher Information. If $\theta$ is a column vector, then the score function is also a column vector, and the Fisher Information takes matrix form, called the Fisher Information Matrix:
$$
\text{F} = \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \nabla \log p(x \vert \theta) \, \nabla \log p(x \vert \theta)^{\text{T}} \right]
$$
In practice, however, $p(x \vert \theta)$ usually has a complicated form, possibly the output of a model, so computing this expectation exactly is intractable. What we actually use most often is the Empirical Fisher, computed from training data $X = \{ x_1, x_2, \cdots, x_N \}$:
$$
\text{F} = \frac{1}{N} \sum_{i=1}^{N} \nabla \log p(x_i \vert \theta) \, \nabla \log p(x_i \vert \theta)^{\text{T}}
$$
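The empirical Fisher can be sketched as the average outer product of per-sample score vectors. Continuing the toy Gaussian with $\theta = (\mu, \log\sigma)$ (my own example): for this parameterisation the exact Fisher is $\mathrm{diag}(1/\sigma^2,\, 2)$, which the estimate should approach.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0
x = rng.normal(mu, sigma, size=100_000)
scores = np.stack([(x - mu) / sigma**2,             # score w.r.t. mu
                   ((x - mu) / sigma) ** 2 - 1.0],  # score w.r.t. log_sigma
                  axis=1)                           # shape (N, 2)
F = scores.T @ scores / len(scores)                 # empirical Fisher
print(F)   # approx [[1, 0], [0, 2]]
```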
1.3 Relationship between the Fisher matrix and the Hessian
The preceding was background; now the main point, the relationship between the Fisher matrix and the Hessian. One can prove that the negative expectation of the Hessian of the log-likelihood equals the Fisher Information Matrix.
Claim: The negative expected Hessian of the log-likelihood is equal to the Fisher Information Matrix F.
Proof: the key idea is that the Hessian of the log-likelihood is given by the Jacobian of its gradient:
$$
\begin{aligned}
\text{H}_{\log p(x \vert \theta)}
&= \text{J} \left[\frac{\nabla p(x \vert \theta)}{p(x \vert \theta)}\right] \\
&= \frac{ \text{H}_{p(x \vert \theta)} \, p(x \vert \theta) - \nabla p(x \vert \theta) \, \nabla p(x \vert \theta)^{\text{T}}}{p(x \vert \theta) \, p(x \vert \theta)} \\
&= \frac{\text{H}_{p(x \vert \theta)} \, p(x \vert \theta)}{p(x \vert \theta) \, p(x \vert \theta)} - \frac{\nabla p(x \vert \theta) \, \nabla p(x \vert \theta)^{\text{T}}}{p(x \vert \theta) \, p(x \vert \theta)} \\
&= \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} - \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right) \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)}\right)^{\text{T}}
\end{aligned}
$$
The main things to note in this derivation: $p(x \vert \theta)$ is a scalar, while $\nabla p(x \vert \theta)$, the gradient with respect to the parameters, is a column vector.
Then, taking the expectation with respect to the model, we have:
$$
\begin{aligned}
\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \text{H}_{\log p(x \vert \theta)} \right]
&= \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} - \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right) \left( \frac{\nabla p(x \vert \theta)}{p(x \vert \theta)} \right)^{\text{T}} \right] \\
&= \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} \right] - \mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \nabla \log p(x \vert \theta) \, \nabla \log p(x \vert \theta)^{\text{T}} \right] \\
&= \int \frac{\text{H}_{p(x \vert \theta)}}{p(x \vert \theta)} \, p(x \vert \theta) \, \text{d}x \, - \text{F} \\
&= \text{H}_{\int p(x \vert \theta) \, \text{d}x} \, - \text{F} \\
&= \text{H}_{1} - \text{F} \\
&= -\text{F} \, .
\end{aligned}
$$
We have therefore shown $\text{F} = -\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \text{H}_{\log p(x \vert \theta)} \right]$, which completes the proof. F can be viewed as a measure of the curvature of the log-likelihood. A natural question follows: in second-order algorithms such as Newton's method, which need the Hessian matrix, can the Fisher matrix replace the Hessian? This leads to the natural gradient method discussed next.
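This identity can be checked numerically on the toy Gaussian with $\theta = (\mu, \log\sigma)$ (my own example, using the hand-derived per-sample Hessian of $\log p$): the negative expected Hessian should match the exact Fisher $\mathrm{diag}(1/\sigma^2,\, 2)$.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0
x = rng.normal(mu, sigma, size=100_000)
# Per-sample Hessian entries of log p(x | theta) w.r.t. theta = (mu, log_sigma):
h11 = np.full_like(x, -1.0 / sigma**2)    # d2/dmu2
h12 = -2.0 * (x - mu) / sigma**2          # d2/(dmu dlog_sigma)
h22 = -2.0 * (x - mu) ** 2 / sigma**2     # d2/dlog_sigma2
neg_EH = -np.array([[h11.mean(), h12.mean()],
                    [h12.mean(), h22.mean()]])
print(neg_EH)   # approx [[1, 0], [0, 2]], the exact Fisher
```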
2. Natural Gradient Descent
First, the concepts of parameter space and distribution space, which lead to different ways of understanding gradient descent.
Parameter space: the most common way to solve an optimization problem is gradient descent, where each step moves along the negative gradient $-\nabla_\theta \mathcal{L}(\theta)$. The negative gradient is the steepest-descent direction of the loss in parameter space, within a local neighbourhood of the current parameter value $\theta$:
$$
\frac{-\nabla_\theta \mathcal{L}(\theta)}{\lVert \nabla_\theta \mathcal{L}(\theta) \rVert} = \lim_{\epsilon \to 0} \frac{1}{\epsilon} \mathop{\text{arg min}}_{d \text{ s.t. } \lVert d \rVert \leq \epsilon} \mathcal{L}(\theta + d) \, .
$$
This expression says that the steepest-descent direction in parameter space is obtained by picking a vector $d$ such that the new parameter $\theta + d$ lies within the $\epsilon$-neighbourhood of the current $\theta$, and choosing the $d$ that minimises the loss. Note that the neighbourhood is defined by the Euclidean norm, so gradient descent depends on the Euclidean geometry of parameter space.
Distribution space: at the same time, if our goal is to minimise the loss (maximise the likelihood), it is natural to take optimization steps in the space of all possible likelihoods realisable through the parameters $\theta$. Since the likelihood is itself a probability distribution, we call the space it lives in the distribution space. It therefore makes sense to take the steepest-descent direction in distribution space rather than parameter space.
What distance metric should we use in distribution space? A common choice is KL divergence, which is widely used to measure how close two distributions are. Strictly speaking, KL divergence is asymmetric and therefore not a true distance metric, yet many places still use it to measure closeness between distributions. (As $d$ goes to zero, KL divergence is asymptotically symmetric, so within a local neighbourhood it is approximately symmetric [3].)
2.1 Steepest descent in distribution space: the natural gradient method
After all that setup, we can finally derive the natural gradient method.
First we derive the Taylor expansion of the KL divergence, which has the following form:
$$
\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta + d)] \approx \frac{1}{2} d^\text{T} \text{F} d
$$
Proof: write out the second-order Taylor expansion:
$$
\begin{aligned}
\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta+d)]
&\approx \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]\vert_{\theta' = \theta}
+ \left( \nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')] \vert_{\theta' = \theta} \right)^\text{T} d \\
&\quad + \frac{1}{2} d^\text{T} \, \nabla_{\theta'}^2 \, \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]\vert_{\theta' = \theta} \, d \\
&= \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta)]
- \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_\theta \log p(x \vert \theta) ]^\text{T} d
+ \frac{1}{2} d^\text{T} \text{F} d \\
&= \frac{1}{2} d^\text{T} \text{F} d
\end{aligned}
$$
Why introduce $\theta'$: treat the first $p(x \vert \theta)$ in the KL divergence as a fixed distribution, while the parameter of the second distribution is what varies. Let us examine the three terms after the $\approx$ sign in turn:
The first term of the Taylor expansion is $\text{KL}[p_{\theta} \, \Vert \, p_{\theta}] = 0$.
Derivation of the second term:
$$
\begin{aligned}
\nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]
&= \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ] \\
&= - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_{\theta'} \log p(x \vert \theta') ]
\end{aligned}
$$
Evaluated at $\theta' = \theta$, the second term is exactly the expectation of the score function, which we proved to be zero earlier in the Fisher matrix section.
For the third term we need the result proved in Section 1, $\text{F} = -\mathop{\mathbb{E}}_{p(x \vert \theta)} \left[ \text{H}_{\log p(x \vert \theta)} \right]$, together with the following property: the Fisher Information Matrix F is the Hessian of the KL divergence between the two distributions $p(x \vert \theta)$ and $p(x \vert \theta')$, with respect to $\theta'$, evaluated at $\theta' = \theta$. The derivation:
$$
\text{KL} [p(x \vert \theta) \, \Vert \, p(x \vert \theta')] = \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ]
$$
The first derivative with respect to $\theta'$ is:
$$
\begin{aligned}
\nabla_{\theta'} \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]
&= \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta) ] - \nabla_{\theta'} \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \log p(x \vert \theta') ] \\
&= - \mathop{\mathbb{E}}_{p(x \vert \theta)} [ \nabla_{\theta'} \log p(x \vert \theta') ] \\
&= - \int p(x \vert \theta) \, \nabla_{\theta'} \log p(x \vert \theta') \, \text{d}x
\end{aligned}
$$
The second derivative is:
$$
\begin{aligned}
\nabla_{\theta'}^2 \, \text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta')]\vert_{\theta' = \theta}
&= - \int p(x \vert \theta) \, \nabla_{\theta'}^2 \log p(x \vert \theta')\vert_{\theta' = \theta} \, \text{d}x \\
&= - \int p(x \vert \theta) \, \text{H}_{\log p(x \vert \theta)} \, \text{d}x \\
&= - \mathop{\mathbb{E}}_{p(x \vert \theta)} [\text{H}_{\log p(x \vert \theta)}] \\
&= \text{F}
\end{aligned}
$$
This gives the second-order Taylor expansion of the KL divergence:
$$
\text{KL}[p(x \vert \theta) \, \Vert \, p(x \vert \theta + d)] \approx \frac{1}{2} d^\text{T} \text{F} d
$$
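This quadratic approximation can be checked on the toy Gaussian (my own example), using the closed-form KL divergence between two univariate Gaussians and the exact Fisher $\mathrm{diag}(1/\sigma^2,\, 2)$ for $\theta = (\mu, \log\sigma)$:

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    """Closed-form KL( N(mu1, s1^2) || N(mu2, s2^2) )."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5

mu, sigma = 0.0, 1.0
F = np.diag([1.0 / sigma**2, 2.0])   # exact Fisher for theta = (mu, log_sigma)
d = np.array([0.01, 0.01])           # small step in (mu, log_sigma)
exact = kl_gauss(mu, sigma, mu + d[0], sigma * np.exp(d[1]))
approx = 0.5 * d @ F @ d
print(exact, approx)                 # nearly equal for small d
```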
Now we ask which update vector $d$ minimises the loss $\mathcal{L}$ in distribution space, i.e. in which direction the loss decreases the most per unit of KL divergence. This is analogous to steepest descent, but in the distribution space with KL divergence as its metric rather than the usual parameter space with the Euclidean metric. To this end we minimise:
$$
d^* = \mathop{\text{arg min}}_{d \text{ s.t. } \text{KL}[p_\theta \Vert p_{\theta + d}] \leq c} \mathcal{L} (\theta + d) \, ,
$$
Writing this minimisation in Lagrangian form, approximating the KL divergence with its second-order Taylor expansion and $\mathcal{L}$ with its first-order Taylor expansion:
$$
\begin{aligned}
d^* &= \mathop{\text{arg min}}_d \, \mathcal{L} (\theta + d) + \lambda \, (\text{KL}[p_\theta \Vert p_{\theta + d}] - c) \\
&\approx \mathop{\text{arg min}}_d \, \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\text{T} d + \frac{1}{2} \lambda \, d^\text{T} \text{F} d - \lambda c
\end{aligned}
$$
where $\lambda$ is the Lagrange multiplier. To solve this optimization problem, we set the gradient with respect to $d$ to zero:
$$
\begin{aligned}
0 &= \frac{\partial}{\partial d} \left[\mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\text{T} d + \frac{1}{2} \lambda \, d^\text{T} \text{F} d - \lambda c\right] \\
&= \nabla_\theta \mathcal{L}(\theta) + \lambda \, \text{F} d \\
\lambda \, \text{F} d &= -\nabla_\theta \mathcal{L}(\theta) \\
d &= -\frac{1}{\lambda} \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta)
\end{aligned}
$$
Setting aside the factor $\frac{1}{\lambda}$ (it can be absorbed into the learning rate), the optimal update direction in distribution space is $-\text{F}^{-1} \nabla_\theta \mathcal{L}(\theta)$. (Compare Newton's method, a second-order optimizer whose update direction is $-\text{H}^{-1} \nabla_\theta \mathcal{L}(\theta)$; the two are very similar.)
We define the natural gradient as $\tilde{\nabla}_\theta \mathcal{L}(\theta) = \text{F}^{-1} \nabla_\theta \mathcal{L}(\theta)$. Natural gradient descent then simply uses $\tilde{\nabla}_\theta \mathcal{L}(\theta)$ in place of the plain gradient at each update step. (In practice we usually use the batch-mode Empirical Fisher, $\text{F} = \frac{1}{N} \sum_{i=1}^{N} \nabla \log p(x_i \vert \theta) \, \nabla \log p(x_i \vert \theta)^{\text{T}}$.)
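A minimal natural-gradient-descent sketch on the toy Gaussian model used above (my own illustration, not the author's code): we fit $\theta = (\mu, \log\sigma)$ by maximum likelihood, and since the exact Fisher $\text{F} = \mathrm{diag}(1/\sigma^2,\, 2)$ is known for this model, each step preconditions the gradient with $\text{F}^{-1}$. With real models one would substitute the empirical Fisher (plus damping).

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(3.0, 2.0, size=5_000)   # ground truth: mu = 3, sigma = 2

theta = np.zeros(2)                        # start at (mu, log_sigma) = (0, 0)
for _ in range(50):
    mu, log_sigma = theta
    sigma = np.exp(log_sigma)
    z = (data - mu) / sigma
    grad = -np.array([z.mean() / sigma,          # d NLL / d mu
                      (z**2 - 1.0).mean()])      # d NLL / d log_sigma
    F = np.diag([1.0 / sigma**2, 2.0])           # exact Fisher for this model
    theta = theta - np.linalg.solve(F, grad)     # natural gradient step
print(theta)   # approx (3.0, log 2.0)
```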
Discussion: an analogy with Adam
For very simple models with little data, we have seen that natural gradient descent is easy to implement. But deep learning models famously have enormous numbers of parameters; models with tens of millions or even billions of parameters are common, and even a single layer can have millions. For such models the Fisher information matrix is impractical to compute, store, or invert. This is the same reason second-order methods are unpopular in deep learning.
One way around this is to compute an approximate Fisher/Hessian. Methods like Adam [5] track first- and second-moment moving averages of the gradient (m and v). m is momentum and is not discussed here. v can be viewed as an approximation of the Fisher information matrix, constrained to be diagonal (the diagonal entries of the covariance are the squared gradients). Thus Adam needs only $O(n)$ space to store the approximation of F rather than $O(n^2)$, and can perform the inversion in $O(n)$ rather than $O(n^3)$. In practice Adam works very well and is a standard baseline optimizer for deep neural networks.
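The analogy can be sketched as follows (my own illustration, a simplified Adam-style update without bias correction): v is an exponential moving average of squared gradients, i.e. a diagonal Fisher-like estimate, so "inverting" it is an elementwise division costing $O(n)$ instead of an $O(n^3)$ matrix solve.

```python
import numpy as np

def adam_like_step(theta, grad, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One simplified Adam-style update (no bias correction)."""
    m = b1 * m + (1 - b1) * grad              # first moment (momentum)
    v = b2 * v + (1 - b2) * grad**2           # diagonal Fisher-like estimate
    theta = theta - lr * m / (np.sqrt(v) + eps)   # O(n) preconditioned step
    return theta, m, v

theta = np.zeros(3)
m, v = np.zeros(3), np.zeros(3)
theta, m, v = adam_like_step(theta, np.ones(3), m, v)
print(theta)   # each entry moved in the negative gradient direction
```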
OK, this post is basically done. I will continue the topic later with notes on how to speed up natural gradient methods, mainly the well-known K-FAC algorithm. I may also add some further discussion of natural gradients in a few days; see [6][7]. TBD…
References
[1] https://wiseodd.github.io/techblog/2018/03/11/fisher-information/
[2] https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/
[3] Martens, James. “New insights and perspectives on the natural gradient method.” arXiv preprint arXiv:1412.1193 (2014).
[4] Ly, Alexander, et al. “A tutorial on Fisher information.” Journal of Mathematical Psychology 80 (2017): 40-55
[5] Kingma, Diederik P., and Jimmy Ba. "Adam: A Method for Stochastic Optimization." ICLR 2015.
[6] Understanding the natural gradient from multiple angles (in Chinese), https://zhuanlan.zhihu.com/p/82934100
[7] How to understand natural gradient descent? (in Chinese), https://www.zhihu.com/question/266846405