《統計學習方法》代碼全解析——第一部分統計學習方法概論

1.統計學習是關於計算機基於數據構建概率統計模型並運用模型對數據進行分析與預測的一門學科。統計學習包括監督學習、非監督學習、半監督學習和強化學習。
2.統計學習方法三要素——模型、策略、算法,對理解統計學習方法起到提綱挈領的作用。
3.本書主要討論監督學習,監督學習可以概括如下:從給定有限的訓練數據出發, 假設數據是獨立同分布的,而且假設模型屬於某個假設空間,應用某一評價準則,從假設空間中選取一個最優的模型,使它對已給訓練數據及未知測試數據在給定評價標準意義下有最準確的預測。
4.統計學習中,進行模型選擇或者說提高學習的泛化能力是一個重要問題。如果只考慮減少訓練誤差,就可能產生過擬合現象。模型選擇的方法有正則化與交叉驗證。學習方法泛化能力的分析是統計學習理論研究的重要課題。
5.分類問題、標註問題和迴歸問題都是監督學習的重要問題。本書中介紹的統計學習方法包括感知機、 𝑘 近鄰法、樸素貝葉斯法、決策樹、邏輯斯諦迴歸與最大熵模型、支持向量機、提升方法、EM算法、隱馬爾可夫模型和條件隨機場。這些方法是主要的分類、標註以及迴歸方法。它們又可以歸類爲生成方法與判別方法。

使用最小二乘法擬和曲線

 舉例:我們用目標函數 𝑦=𝑠𝑖𝑛2𝜋𝑥 y=sin2πx , 加上一個正態分佈的噪音干擾,用多項式去擬合【例1.1 11頁】

import numpy as np
import scipy as sp
from scipy.optimize import leastsq
import matplotlib.pyplot as plt
%matplotlib inline

# 目標函數
def real_func(x):
    return np.sin(2*np.pi*x)

# 多項式
def fit_func(p, x):
    f = np.poly1d(p)
    return f(x)

# 殘差
def residuals_func(p, x, y):
    ret = fit_func(p, x) - y
    return ret

# 十個點
x = np.linspace(0, 1, 10)
x_points = np.linspace(0, 1, 1000)
# 加上正態分佈噪音的目標函數的值
y_ = real_func(x)
y = [np.random.normal(0, 0.1) + y1 for y1 in y_]


def fitting(M=0):
    """
    M    爲 多項式的次數
    """
    # 隨機初始化多項式參數
    p_init = np.random.rand(M + 1)
    # 最小二乘法
    p_lsq = leastsq(residuals_func, p_init, args=(x, y))
    print('Fitting Parameters:', p_lsq[0])

    # 可視化
    plt.plot(x_points, real_func(x_points), label='real')
    plt.plot(x_points, fit_func(p_lsq[0], x_points), label='fitted curve')
    plt.plot(x, y, 'bo', label='noise')
    plt.legend()
    return p_lsq

# M=0 
p_lsq_0 = fitting(M=0)

# M=1 
p_lsq_1 = fitting(M=1)

# M=9 
p_lsq_9 = fitting(M=9)

 Fitting Parameters: [-1.70872086e+04  7.01364939e+04 -1.18382087e+05  1.06032494e+05  -5.43222991e+04  1.60701108e+04 -2.65984526e+03  2.12318870e+02  -7.15931412e-02  3.53804263e-02]

 

regularization = 0.0001


def residuals_func_regularization(p, x, y):
    ret = fit_func(p, x) - y
    ret = np.append(ret,
                    np.sqrt(0.5 * regularization * np.square(p)))  # L2範數作爲正則化項
    return ret

# 最小二乘法,加正則化項
p_init = np.random.rand(9 + 1)
p_lsq_regularization = leastsq(
    residuals_func_regularization, p_init, args=(x, y))

plt.plot(x_points, real_func(x_points), label='real')
plt.plot(x_points, fit_func(p_lsq_9[0], x_points), label='fitted curve')
plt.plot(
    x_points,
    fit_func(p_lsq_regularization[0], x_points),
    label='regularization')
plt.plot(x, y, 'bo', label='noise')
plt.legend()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章