多项式曲线拟合
Polynomial Curve Fitting
实验目标
利用Python实现多项式的曲线拟合。
实现过程
- Step 1 :生成观测集和目标函数
假设训练集由x的N次观测 得到,x均匀分布于区间[0,1]。对应的观测集为 ,目标函数为。
所以,为了通过训练集和观测集拟合出预测函数,使其尽可能接近目标函数,我们通过训练集加上随机高斯噪声输入到目标函数得到。
首先,图一中分别绘制了 标准曲线(如绿线所示)和添加了噪声的观测集(样本包含10个点,如蓝点所示)。
import numpy as np
import matplotlib.pyplot as plt
#标准曲线
x = np.linspace(0, 1, 100)
t = np.sin(2 * np.pi * x)
#采样函数
def get_data(N):
x_n = np.linspace(0,1,N)
t_n = np.sin(2 * np.pi * x_n) + np.random.normal(scale=0.15, size=N) #add Gaussian Noise
return x_n, t_n
#绘制部分组件函数
def draw_ticks():
plt.tick_params(labelsize=15)
plt.xticks(np.linspace(0, 1, 2))
plt.yticks(np.linspace(-1, 1, 3))
plt.ylim(-1.5, 1.5)
font = {'family':'Times New Roman','size':20}
plt.xlabel('x', font)
plt.ylabel('t',font, rotation='horizontal')
#采样
x_10, t_10 = get_data(10)
#图像绘制部分
plt.figure(1, figsize=(8,5))
plt.plot(x, t, 'g',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3, label="training data")
draw_ticks()
plt.title('Figure 1 : sample curve', font)
plt.savefig('1.png', dpi=400)
绿色的曲线为要拟合的目标函数。然后,使用多项式函数来拟合生成的数据。多项式定义如下:
M是多项式的阶数,ω0,…,ωM 是多项式的系数,记为W。然后使用均方误差作为误差函数对拟合出的多项式进行评估,公式如下:
表示为矩阵形式:
拟合数据的目的即为最小化误差函数,因为误差函数是多项式系数W的二次函数,所以存在唯一最小值,且在导数为零处取得。对W求导并令导数为零得到:
故可以通过矩阵运算得到W。
#拟合函数(lamda默认为0,即无正则项)
def regress(M, N, x, x_n, t_n, lamda=0):
print("-----------------------M=%d, N=%d-------------------------" %(M,N))
order = np.arange(M+1)
order = order[:, np.newaxis]
e = np.tile(order, [1,N])
XT = np.power(x_n, e)
X = np.transpose(XT)
a = np.matmul(XT, X) + lamda*np.identity(M+1) #X.T * X
b = np.matmul(XT, t_n) #X.T * T
w = np.linalg.solve(a,b) #aW = b => (X.T * X) * W = X.T * T
print("W:")
print(w)
e2 = np.tile(order, [1,x.shape[0]])
XT2 = np.power(x, e2)
p = np.matmul(w, XT2)
return p
- Step 2 :比较不同阶数多项式的拟合效果
分别选择 M = 0, 1, 3, 9 不同多项式阶数对数据进行拟合。图中红线为拟合结果。
#M=0, N=10
p = regress(0, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(2, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.title('Figure 2 : M = 0, N = 10', font)
plt.text(0.8, 0.9,'M = 0', font, style = 'italic')
plt.savefig('2.png', dpi=400)
#M=1, N=10
p = regress(1, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(3, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.title('Figure 3 : M = 1, N = 10', font)
plt.text(0.8, 0.9,'M = 1', font, style = 'italic')
plt.savefig('3.png', dpi=400)
#M=3, N=10
p = regress(3, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(4, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.title('Figure 4 : M = 3, N = 10', font)
plt.text(0.8, 0.9,'M = 3', font, style = 'italic')
plt.savefig('4.png', dpi=400)
#M=9, N=10
p = regress(9, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(5, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.9,'M = 9', font, style = 'italic')
plt.title('Figure 5 : M = 9, N = 10', font)
plt.savefig('5.png', dpi=400)
拟合结果显示:
- 当 M = 0 和 1 时,多项式的拟合效果很差,无法代表目标函数,即欠拟合现象。
- 当 M = 3 时,多项式已经比较接近目标函数。
- 当 M = 9 时,多项式函数精确地通过每个观测点,但是曲线呈现震荡形式并对噪声敏感,出现过拟合现象。
其中,欠拟合和过拟合都无法代表目标函数。
- Step 3 :通过增大数据规模改善过拟合现象
当模型复杂度确定时,考虑利用更多的观测点(15个和100个)对9阶多项式进行拟合。
M=9
N=15
x_15, t_15 = get_data(N)
p = regress(M, N, x, x_15, t_15)
#图像绘制部分
plt.figure(6, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_15, t_15, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.65,'N = 15', font, style = 'italic')
plt.title('Figure 6 : M = 9, N = 15', font)
plt.savefig('6.png', dpi=400)
M=9
N=100
x_100, t_100 = get_data(N)
p = regress(M, N, x, x_100, t_100)
#图像绘制部分
plt.figure(7, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_100, t_100, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.65,'N = 100', font, style = 'italic')
plt.title('Figure 7 : M = 9, N = 100', font)
plt.savefig('7.png', dpi=400)
可以看到,数据规模的增加能够有效的减轻模型的过拟合问题。但是实际应用中可能无法获得足够数据量。
- Step 4 : 通过正则化改善过拟合现象
除了增加数据量来减轻过拟合的影响,还可以通过正则化方法。在定义误差函数时增加惩罚项,使多项式系数被有效控制,不会过大。
误差函数变为如下形式:
求导置零得到:
然后,我们进行当多项式阶数 时,有 个采样点的情况下,λ较小和较大时(如 和 ) 时对过拟合现象的实验。
M=9
N=10
x_10, t_10 = get_data(N)
#lnλ = 0
p = regress(M, N, x, x_10, t_10, np.exp(0))
#图像绘制部分
plt.figure(8, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.9,' lnλ = 0', font, style = 'italic')
plt.title('Figure 8 : M = 9, N = 10, lnλ = 0', font)
plt.savefig('8.png', dpi=400)
#lnλ = -18
p = regress(M, N, x, x_10, t_10, np.exp(-18))
#图像绘制部分
plt.figure(9, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.9,' lnλ = -18', font, style = 'italic')
plt.title('Figure 9 : M = 9, N = 10, lnλ = -18', font)
plt.savefig('9.png', dpi=400)
结果显示,加上了正则项后,λ 较小时有效地改善了高阶多项式的过拟合现象,但是当 λ 过大时会过度抑制模型系数。所以,根据模型的复杂度来进行合适的正则化对于拟合结果非常重要。
实验总结
- 本次实验主要实现了多项式的曲线拟合。在拟合过程中,当模型的复杂度被限制而出现过拟合现象时,可以通过增加数据规模来进行改善;当数据有限时可以通过正则化的方法来抑制过拟合;也可以二者相结合得到更好的效果。
- 学习并实践了Python和Numpy的基本使用,尤其是矩阵的运算部分;学习利用matplotlib进行可视化。
继续加油嗷~ヾ(◍°∇°◍)ノ゙
ps:第一次用Markdown写博,还挺酷的哈哈哈~