《統計學習方法》第1章_統計學習方法概論

# encoding:utf-8
import numpy as np
from scipy.optimize import leastsq
import matplotlib.pyplot as plt

"""我們用目標函數y=sin2πx,加上一個正態分佈的噪音干擾,用多項式去擬合(例1.1 11頁)"""
"""目標函數"""


def real_func(x):
    """Return the target function sin(2*pi*x) evaluated at ``x``.

    Args:
        x: A number or NumPy array of sample positions.

    Returns:
        ``np.sin(2 * np.pi * x)``, element-wise when ``x`` is an array.
    """
    angular_frequency = 2 * np.pi
    return np.sin(angular_frequency * x)


"""多項式"""


def fit_func(p, x):
    """Evaluate the polynomial with coefficients ``p`` at ``x``.

    Args:
        p: Polynomial coefficients as accepted by ``np.poly1d``
            (highest power first).
        x: Point or array of points at which to evaluate the polynomial.

    Returns:
        The polynomial value(s) at ``x``.
    """
    return np.poly1d(p)(x)


"""殘差"""


def residuals_func(p, x, y):
    """Return the residuals of the polynomial ``p`` against the data.

    Args:
        p: Polynomial coefficients passed on to ``fit_func``.
        x: Sample positions.
        y: Observed values at ``x``.

    Returns:
        ``fit_func(p, x) - y``, i.e. prediction minus observation.
    """
    return fit_func(p, x) - y


# Generate ten training points.
# np.linspace(start, stop, num) builds an evenly spaced sequence: the first
# argument is the start, the second the end, the third the number of values.
x = np.linspace(0, 1, 10)
# Denser grid over the same interval, used for drawing the curves.
x_points = np.linspace(0, 1, 100)

# Target-function values with normally distributed noise added
# (mean 0, standard deviation 0.1).
y_ = real_func(x)
# Draw all noise samples in one vectorised call so that y is an ndarray
# (like x) rather than a Python list of scalars.
y = y_ + np.random.normal(0, 0.1, size=y_.shape)


def fitting(M=0):
    """Least-squares fit a degree-``M`` polynomial to the samples and plot it.

    Starts from random coefficients, minimises ``residuals_func`` over the
    module-level samples ``(x, y)`` with ``leastsq``, prints the fitted
    coefficients, and draws the target curve, the fitted curve and the noisy
    samples in one figure.

    Args:
        M: Degree of the polynomial; ``M + 1`` coefficients are fitted.

    Returns:
        The value returned by ``leastsq``; element ``0`` holds the fitted
        coefficients.
    """
    # Random initial guess for the M + 1 polynomial coefficients.
    start_coeffs = np.random.rand(M + 1)

    # Least squares: residuals_func is the error function, start_coeffs the
    # initial parameters, and args carries the data points.
    solution = leastsq(residuals_func, start_coeffs, args=(x, y))
    fitted_coeffs = solution[0]
    print("Fitting Parameters:", fitted_coeffs)

    # Visualisation: target curve, fitted curve, then the noisy samples.
    plt.plot(x_points, real_func(x_points), label='real')
    plt.plot(x_points, fit_func(fitted_coeffs, x_points), label='fitted curve')
    plt.plot(x, y, 'bo', label='noise')
    plt.legend()
    plt.show()
    return solution


# Fit polynomials of degree M = 0, 1, 3 and 9 in turn; each call prints its
# coefficients and shows a plot.
p_lsq_0, p_lsq_1, p_lsq_3, p_lsq_9 = [fitting(M=degree) for degree in (0, 1, 3, 9)]


# Weight of the L2 penalty used by residuals_func_regularization.
regularization = 0.0001


def residuals_func_regularization(p, x, y):
    """Return the fit residuals extended with an L2 regularisation term.

    Args:
        p: Polynomial coefficients passed on to ``fit_func``.
        x: Sample positions.
        y: Observed values at ``x``.

    Returns:
        A 1-D array: the residuals ``fit_func(p, x) - y`` followed by one
        entry ``sqrt(0.5 * regularization * p_i**2)`` per coefficient.
    """
    fit_error = fit_func(p, x) - y
    # L2 norm as the regularisation term: squaring and summing these entries
    # gives 0.5 * regularization * sum(p_i**2).
    penalty = np.sqrt(0.5 * regularization * np.square(p))
    return np.append(fit_error, penalty)


# Least squares with the regularisation term added.
# A degree-9 polynomial has 10 coefficients; start from random values.
p_init = np.random.rand(10)
p_lsq_regularization = leastsq(
    residuals_func_regularization,
    x0=p_init,
    args=(x, y),
)

# Compare the target curve, the plain degree-9 fit and the regularised fit
# against the noisy samples.
plt.plot(x_points, real_func(x_points), label="real")
plt.plot(
    x_points,
    fit_func(p_lsq_9[0], x_points),
    label="fitted curve",
)
plt.plot(
    x_points,
    fit_func(p_lsq_regularization[0], x_points),
    label="regularization",
)
plt.plot(x, y, "bo", label="noise")
plt.legend()
plt.show()
  • M=0

          Fitting Parameters: [ 0.06225187]

  • M=1

          Fitting Parameters: [-1.52289456  0.82369915]

  • M=3

          Fitting Parameters: [ 19.78479794 -29.56849209   9.62863747   0.15588463]

  • M=9

          Fitting Parameters: [  1.85650557e+04  -8.53696108e+04   1.66671311e+05
                                -1.79727475e+05   1.16312355e+05  -4.57586021e+04
                                 1.05187764e+04  -1.27770807e+03   6.57051632e+01
                                 1.94913510e-01]

  • 最小二乘法,加正則化項(L2範數作爲正則化項)

參考:https://github.com/fengdu78/lihang-code/blob/master/%E7%AC%AC01%E7%AB%A0%20%E7%BB%9F%E8%AE%A1%E5%AD%A6%E4%B9%A0%E6%96%B9%E6%B3%95%E6%A6%82%E8%AE%BA/1.Introduction_to_statistical_learning_methods.ipynb

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章