深度學習知識點(3):優化改進版的梯度下降法

目錄

1、Adagrad法

2、RMSprop法

3、Momentum法

4、Adam法

參考資料:


發展歷史簡括:

標準梯度下降法的缺陷:

如果學習率選的不恰當會出現以上情況。

因此有一些自動調學習率的方法。一般來說,隨着迭代次數的增加,學習率應該越來越小,因爲迭代次數增加後,得到的解應該比較靠近最優解,所以要縮小步長η,那麼有什麼公式嗎?比如:,但是這樣做後,所有參數更新時仍都採用同一個學習率,即學習率不能適應所有的參數更新。

解決方案是:給不同的參數不同的學習率

1、Adagrad法

假設N元函數f(x),針對一個自變量研究Adagrad梯度下降的迭代過程,

可以看出,Adagrad算法中有自適應調整梯度的意味(adaptive gradient),學習率需要除以一個東西,這個東西就是前n次迭代過程中偏導數的平方和再加一個常量最後開根號

舉例:使用Adagrad算法求y = x2的最小值點

導函數爲g(x) = 2x

初始化x(0) = 4,學習率η=0.25,ε=0.1

第①次迭代:

第②次迭代:

 

第③次迭代:

 

求解的過程如下圖所示

對應代碼爲:

from matplotlib import pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
x = np.arange(-4, 4, 0.025)
plt.plot(x,x**2)
plt.title("y = x^2")
def f(x):
    return x**2
def h(x):
    return 2*x
η = 0.25
ε = 0.1
x = 4
iters = 0
sum_square_grad = 0
X = []
Y = []
while iters<12:
    iters+=1
    X.append(x)
    Y.append(f(x))
    sum_square_grad += h(x)**2
    x = x - η/np.sqrt(sum_square_grad+ε)*h(x)
    print(iters,x)
plt.plot(X,Y,"ro")
ax = plt.subplot()
for i in range(len(X)):
    ax.text(X[i], (X[i])**2, "({:.3f},{:.3f})".format(X[i], (X[i])**2), color='red')
plt.show()

缺點:由於分母是累加梯度的平方,到後面累加的比較大時,會導致梯度更新緩慢

2、RMSprop法

AdaGrad算法在迭代後期由於學習率過小,可能較難找到一個有用的解。爲了解決這一問題,RMSprop算法對Adagrad算法做了一點小小的修改,RMSprop使用指數衰減只保留過去給定窗口大小的梯度,使其能夠在找到凸碗狀結構後快速收斂。

假設N元函數f(x),針對一個自變量研究RMSprop梯度下降的迭代過程,

可以看出分母不再是一味的增加,它會重點考慮距離他較近的梯度(指數衰減的效果),也就不會出現Adagrad到後期收斂緩慢的問題

舉例:使用RMSprop算法求y = x2的最小值點

導函數爲h(x) = 2x

初始化g(0) = 1,x(0) = 4,ρ=0.9,η=0.01,ε=10-10

第①次迭代:

 

第②次迭代:

 

求解的過程如下圖所示

對應代碼爲:

from matplotlib import pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
x = np.arange(-4, 4, 0.025)
plt.plot(x,x**2)
plt.title("y = x^2")
def f(x):
    return x**2
def h(x):
    return 2*x
g = 1
x = 4
ρ = 0.9
η = 0.01
ε = 10e-10
iters = 0
X = []
Y = []
while iters<12:
    iters+=1
    X.append(x)
    Y.append(f(x))
    g = ρ*g+(1-ρ)*h(x)**2
    x = x - η/np.sqrt(g+ε)*h(x)
    print(iters,x)
plt.plot(X,Y,"ro")
ax = plt.subplot()
for i in range(len(X)):
    ax.text(X[i], (X[i])**2, "({:.3f},{:.3f})".format(X[i], (X[i])**2), color='red')
plt.show()

3、Momentum法

Momentum是動量的意思,想象一下,一個小車從高坡上衝下來,他不會停在最低點,因爲他還有一個動量,還會向前衝,甚至可以衝過一些小的山丘,如果面對的是較大的坡,他可能爬不上去,最終又會倒車回來,摺疊幾次,停在谷底。

如果使用的是沒有動量的梯度下降法,則可能會停到第一個次優解

最直觀的理解就是,若當前的梯度方向與累積的歷史梯度方向一致,則當前的梯度會被加強,從而這一步下降的幅度更大。若當前的梯度方向與累積的梯度方向不一致,則會減弱當前下降的梯度幅度。

從這幅圖可以看出來,當小球到達A點處,負梯度方向的紅箭頭朝着x軸負向,但是動量方向(綠箭頭)朝着x軸的正向並且長度大於紅箭頭,因此小球在A處還會朝着x軸正向移動。

下面正式介紹Momentum法

假設N元函數f(x),針對一個自變量研究Momentum梯度下降的迭代過程,

v表示動量,初始v=0

α是一個接近於1的數,一般設置爲0.9,也就是把之前的動量縮減到0.9倍

η是學習率

下面通過一個例子演示一下,求y = 2*x^4-x^3-x^2的極小值點

可以看出從-0.8開始迭代,依靠動量成功越過第一個次優解,發現無法越過最優解,摺疊回來,最終收斂到最優解。

對應代碼如下:

from matplotlib import pyplot as plt
import numpy as np

fig = plt.figure()
x = np.arange(-0.8, 1.2, 0.025)
plt.plot(x,-x**3-x**2+2*x**4)
plt.title("y = 2*x^4-x^3-x^2")
def f(x):
    return 2*x**4-x**3-x**2
def h(x):
    return 8*x**3 - 3*x**2 - 2*x
η = 0.05
α = 0.9
v = 0
x = -0.8
iters = 0
X = []
Y = []
while iters<12:
    iters+=1
    X.append(x)
    Y.append(f(x))
    v = α*v - η*h(x)
    x = x + v
    print(iters,x)
plt.plot(X,Y)
plt.show()

4、Adam法

Adam實際上是把momentum和RMSprop結合起來的一種算法

假設N元函數f(x),針對一個自變量研究Adam梯度下降的迭代過程,

下面依次解釋這五個式子:

在①式中,注意m(n)是反向的動量與梯度的和(而在Momentum中是正向動量與負梯度的和,因此⑤式對應的是減號)

在②式中,借鑑的是RMSprop的指數衰減

③和④式目的是糾正偏差

⑤式進行梯度更新

舉例:使用Adagrad算法求y = x2的最小值點

導函數爲h(x) = 2x

初始化x(0) = 4,m(0) = 0,v(0) = 0,β1=0.9,β2=0.999,ε=10-8,η = 0.001

第①次迭代:

 

第②次迭代:

 

求解的過程如下圖所示

 

對應代碼爲:

from matplotlib import pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
x = np.arange(-4, 4, 0.025)
plt.plot(x,x**2)
plt.title("y = x^2")
def f(x):
    return x**2
def h(x):
    return 2*x
x = 4
m = 0
v = 0
β1 = 0.9
β2 = 0.999
η = 0.001
ε = 10e-8
iters = 0
X = []
Y = []
while iters<12:
    iters+=1
    X.append(x)
    Y.append(f(x))
    m = β1*m + (1-β1)*h(x)
    v = β2*v + (1-β2)*h(x)**2
    m_het = m/(1-β1**iters)
    v_het = v/(1-β2**iters)
    x = x - η/np.sqrt(v_het+ε)*m_het
    print(iters,x)
plt.plot(X,Y,"ro")
ax = plt.subplot()
for i in range(len(X)):
    ax.text(X[i], (X[i])**2, "({:.3f},{:.3f})".format(X[i], (X[i])**2), color='red')
plt.show()

參考資料:

https://www.cnblogs.com/itmorn/p/11123789.html

https://blog.csdn.net/u012328159/article/details/80311892

 

 

 

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章