優化與深度學習
優化與估計
儘管優化方法可以最小化深度學習中的損失函數值,但本質上優化方法達到的目標與深度學習的目標並不相同。
- 優化方法目標:訓練集損失函數值
- 深度學習目標:測試集損失函數值(泛化性)
%matplotlib inline
import sys
sys.path.append('/home/kesci/input')
import d2lzh1981 as d2l
from mpl_toolkits import mplot3d # 三維畫圖3d圖
import numpy as np
def f(x): return x * np.cos(np.pi * x)
def g(x): return f(x) + 0.2 * np.cos(5 * np.pi * x)
d2l.set_figsize((5, 3))
x = np.arange(0.5, 1.5, 0.01)
fig_f, = d2l.plt.plot(x, f(x),label="train error")
fig_g, = d2l.plt.plot(x, g(x),'--', c='purple', label="test error")
fig_f.axes.annotate('empirical risk', (1.0, -1.2), (0.5, -1.1),arrowprops=dict(arrowstyle='->'))
fig_g.axes.annotate('expected risk', (1.1, -1.05), (0.95, -0.5),arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('risk')
d2l.plt.legend(loc="upper right")
# str是給數據點添加註釋的內容,支持輸入一個字符串
# xy=是要添加註釋的數據點的位置
# xytext=是註釋內容的位置
# bbox=是註釋框的風格和顏色深度,fc越小,註釋框的顏色越深,支持輸入一個字典
# va="center", ha="center"表示註釋的座標以註釋框的正中心爲準,而不是註釋框的左下角(v代表垂直方向,h代表水平方向)
# xycoords和textcoords可以指定數據點的座標系和註釋內容的座標系,通常只需指定xycoords即可,textcoords默認和xycoords相同
# arrowprops可以指定箭頭的風格支持,輸入一個字典
# 所以總體來說,我們的目標是爲了找到測試集合的最小損失
<matplotlib.legend.Legend at 0x7f43ecf200b8>
優化在深度學習中的挑戰
- 局部最小值
- 鞍點
- 梯度消失
局部最小值
def f(x):
return x * np.cos(np.pi * x)
d2l.set_figsize((4.5, 2.5))
x = np.arange(-1.0, 2.0, 0.1)
fig, = d2l.plt.plot(x, f(x))
fig.axes.annotate('local minimum', xy=(-0.3, -0.25), xytext=(-0.77, -1.0),
arrowprops=dict(arrowstyle='->'))
fig.axes.annotate('global minimum', xy=(1.1, -0.95), xytext=(0.6, 0.8),
arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)');
# 局部最小值和全局最小值
鞍點
- 鞍點(Saddle point)在微分方程中,沿着某一方向是穩定的,另一條方向是不穩定的奇點,叫做鞍點。 在泛函中,既不是極大值點也不是極小值點的臨界點,叫做鞍點。 在矩陣中,一個數在所在行中是最大值,在所在列中是最小值,則被稱爲鞍點。 在物理上要廣泛一些,指在一個方向是極大值,另一個方向是極小值的點。
x = np.arange(-2.0, 2.0, 0.1)
fig, = d2l.plt.plot(x, x**3)
fig.axes.annotate('saddle point', xy=(0, -0.2), xytext=(-0.52, -5.0),
arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)');
e.g.
x, y = np.mgrid[-1: 1: 31j, -1: 1: 31j]
# print(x)
# print(y)
# 一個是行展開,一個是列展開
z = x**2 - y**2
d2l.set_figsize((6, 4))
ax = d2l.plt.figure().add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, **{'rstride': 2, 'cstride': 2})
ax.plot([0], [0], [0], 'ro', markersize=10)
ticks = [-1, 0, 1]
d2l.plt.xticks(ticks)
d2l.plt.yticks(ticks)
ax.set_zticks(ticks)
d2l.plt.xlabel('x')
d2l.plt.ylabel('y');
#從下圖看,從一個方向看,他是極大值點,另一個方向看,它是極小值點
梯度消失
x = np.arange(-2.0, 5.0, 0.01)
fig, = d2l.plt.plot(x, np.tanh(x))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)')
fig.axes.annotate('vanishing gradient', (4, 1), (2, 0.0) ,arrowprops=dict(arrowstyle='->'))
Text(2, 0.0, 'vanishing gradient')
凸性 (Convexity):兩點連線上的點也在圖形內部
基礎
集合
函數
def f(x):
return 0.5 * x**2 # Convex
def g(x):
return np.cos(np.pi * x) # Nonconvex
def h(x):
return np.exp(0.5 * x) # Convex
x, segment = np.arange(-2, 2, 0.01), np.array([-1.5, 1])
d2l.use_svg_display()
_, axes = d2l.plt.subplots(1, 3, figsize=(9, 3))
for ax, func in zip(axes, [f, g, h]):
ax.plot(x, func(x))
ax.plot(segment, func(segment),'--', color="purple")
# d2l.plt.plot([x, segment], [func(x), func(segment)], axes=ax)
Jensen 不等式:數學歸納法證明
性質
- 無局部極小值
- 與凸集的關係
- 二階條件
無局部最小值
證明:假設存在 是局部最小值,則存在全局最小值 , 使得 , 則對 :
與凸集的關係
對於凸函數 ,定義集合 ,則集合 爲凸集
證明:對於點 , 有 , 故
x, y = np.meshgrid(np.linspace(-1, 1, 101), np.linspace(-1, 1, 101),
indexing='ij')
z = x**2 + 0.5 * np.cos(2 * np.pi * y)
# Plot the 3D surface
d2l.set_figsize((6, 4))
ax = d2l.plt.figure().add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, **{'rstride': 10, 'cstride': 10})
ax.contour(x, y, z, offset=-1)
ax.set_zlim(-1, 1.5)
# Adjust labels
for func in [d2l.plt.xticks, d2l.plt.yticks, ax.set_zticks]:
func([-1, 0, 1])
凸函數與二階導數
是凸函數
必要性 ():
對於凸函數:
故:
充分性 ():
令 爲 上的三個點,由拉格朗日中值定理:
根據單調性,有 , 故:
def f(x):
return 0.5 * x**2
x = np.arange(-2, 2, 0.01)
axb, ab = np.array([-1.5, -0.5, 1]), np.array([-1.5, 1])
d2l.set_figsize((3.5, 2.5))
fig_x, = d2l.plt.plot(x, f(x))
fig_axb, = d2l.plt.plot(axb, f(axb), '-.',color="purple")
fig_ab, = d2l.plt.plot(ab, f(ab),'g-.')
fig_x.axes.annotate('a', (-1.5, f(-1.5)), (-1.5, 1.5),arrowprops=dict(arrowstyle='->'))
fig_x.axes.annotate('b', (1, f(1)), (1, 1.5),arrowprops=dict(arrowstyle='->'))
fig_x.axes.annotate('x', (-0.5, f(-0.5)), (-1.5, f(-0.5)),arrowprops=dict(arrowstyle='->'))
Text(-1.5, 0.125, 'x')
限制條件
拉格朗日乘子法
懲罰項
欲使 , 將項 加入目標函數,如多層感知機章節中的
投影
總結
- 鞍點是對所有自變量一階偏導數都爲0,且Hessian矩陣特徵值有正有負的點,代表在某一個方向是極大值點,某一個方向是極小值點
- 假設A和B都是凸集合,那以下是凸集合的是:A和B的交集