線性模型(二)之多項式擬合

1. 多項式擬合問題

  多項式擬合(polynominal curve fitting)是一種線性模型,模型和擬合參數的關係是線性的。多項式擬合的輸入是一維的,即x=x ,這是多項式擬合和線性迴歸問題的主要區別之一。

  多項式擬合的目標是構造輸入xM 階多項式函數,使得該多項式能夠近似表示輸入x 和輸出y 的關係,雖然實際上xy 的關係並不一定是多項式,但使用足夠多的階數,總是可以逼近表示輸入x 和輸出y 的關係的。

  多項式擬合問題的輸入可以表示如下:

D={(x1,y1),(x2,y2),...,(xi,yi),...,(xN,yN)}xiRyiR

  目標輸出是得到一個多項式函數:

f(x)=w1x1+w2x2+wixi+...+wMxM+b=(i=1Mwixi)+b

其中M 表示最高階數爲M

  可見在線性擬合的模型中,共包括了(M+1) 個參數,而該模型雖然不是輸入x 的線性函數,但卻是(M+1) 個擬合參數的線性函數,所以稱多項式擬合爲線性模型。對於多項式擬合問題,其實就是要確定這(M+1) 個參數,這裏先假設階數M 是固定的(M 是一個超參數,可以用驗證集來確定M 最優的值,詳細的關於M 值確定的問題,後面再討論),重點就在於如何求出這(M+1) 個參數的值。

2.優化目標

  多項式擬合是利用多項式函數逼近輸入x 和輸出y 的函數關係,通過什麼指標來衡量某個多項式函數的逼近程度呢?(其實這就是誤差/損失函數)。擬合/迴歸問題常用的評價指標是均方誤差(在機器學習中的模型評估與度量博客中,我進行了介紹)。多項式擬合問題也同樣採用該評價指標,以均方誤差作爲誤差/損失函數,誤差函數越小,模型越好。

E(w,b)=1Ni=1N[f(xi)yi]2

  係數1N 是一常數,對優化結果無影響,可以去除,即將均方誤差替換爲平方誤差:

E(w,b)=i=1N[f(xi)yi]2

   到這裏,就成功把多項式擬合問題變成了最優化問題,優化問題可表示爲:

argminw,bE(w,b)

即需要求得參數{w1,...,wM,b} 的值,使得E(w,b) 最小化。那麼如何對該最優化問題求解呢?

3. 優化問題求解

3.1 求偏導,聯立方程求解

   直觀的想法是,直接對所有參數求偏導,令偏導爲0,再聯立這M+1 個方程求解(因爲共有M+1 個參數,故求偏導後也是得到M+1 個方程)。

E(w,b)=i=1N[f(xi)yi]2=i=1N[(w1xi1+w2xi2+wixij+...+wMxiM+b)yi]2

利用E(w,b) 對各個參數求偏導,如下:

E(w,b)wj=2i=1N[(w1xi1+w2xi2+wixij+...+wMxiM+b)yi]xijE(w,b)b=2i=1N[(w1xi1+w2xi2+wixij+...+wMxiM+b)yi]

求導之後,將各個點(xi,yi) 的值帶入偏導公式,聯立方程求解即可。

  針對該解法,可以舉個例子詳細說明,比如有兩個點(2,3),(5,8) ,需要利用二階多項式f(x)=w1x+w2x2+b 擬合。求解過程如下:

  1. 該二階多項式對參數求偏導得到

    E(w,b)wj=2i=12[(w1xi1+w2xi2+b)yi]xij=[(w1x1+w2x12+b)y1]x1j+[(w1x2+w2x22+b)y2]x2jE(w,b)b=2i=12[(w1xi1+w2xi2+b)yi]=[(w1x1+w2x12+b)y1]+[(w1x2+w2x22+b)y2]
  2. 將點(2,3),(5,8) 帶入方程,可以得到3個方程,

    2b+7w1+29w2=117b+29w1+133w2=4629b+133w1+641w2=212
  3. 聯立這三個方程求解,發現有無窮多的解,只能得到3w1+21w2=5 ,這三個方程是線性相關的,故沒有唯一解。

  該方法通過求偏導,再聯立方程求解,比較複雜,看着也很不美觀。那麼有沒有更加方便的方法呢?

3.2 最小二乘法

   其實求解該最優化問題(平方和的最小值)一般會採用最小二乘法(其實最小二乘法和求偏導再聯立方程求解的方法無本質區別,求偏導也是最小二乘法,只是這裏介紹最小二乘的矩陣形式而已)。最小二乘法(least squares),從英文名非常容易想到,該方法就是求解平方和的最小值的方法。

  可以將誤差函數以矩陣的表示(N 個點,最高M 階)爲:

Xwy2

其中,把偏置b 融合到了參數w 中,

w={b,w1,w2,...,wM}

X 則表示輸入矩陣,

[1x1x12...x1M1x2x22...x2M...............1xNxN2...xNM]

y 則表示標註向量,

y={y1,y2,...,yN}T

因此,最優化問題可以重新表示爲

minwXwy2

對其求導,

Xwy2w=(Xwy)T(Xwy)w=(wTXTyT)(Xwy)w=(wTXTXwyTXwwTXTy+yTy)w

在繼續對其求導之前,需要先補充一些矩陣求導的先驗知識(常見的一些矩陣求導公式可以參見轉載的博客https://blog.csdn.net/lipengcn/article/details/52815429),如下:

xTax=aaxx=aTxTAx=Ax+ATx

根據上面的矩陣求導規則,繼續進行損失函數的求導

Xwy2w=(wTXTXwyTXwwTXTy+yTy)w=XTXw+(XTX)Tw(yTX)TXTy=2XTXw2XTy

其中XTXw=(XTX)Tw .令求導結果等於0,即可以求導問題的最小值。

2XTXw2XTy=0w=(XTX)1XTy

  再利用最小二乘法的矩陣形式對前面的例子進行求解,用二階多項式擬合即兩個點(2,3),(5,8)

  1. 表示輸入矩陣 X 和標籤向量y

    X=[1241525]y=[38]T
  2. 計算XTX

    XTX=[272972913329133641]
  3. 矩陣求逆,再做矩陣乘法運算
    XTX 不可逆,故無唯一解。

  關於矩陣的逆是否存在,可以通過判斷矩陣的行列式是否爲0(det(A)=?0 來判斷,也可以通過初等行變換,觀察矩陣的行向量是否線性相關,在這個例子下,矩陣不可逆,故有無窮多解。但如果新增一個點(4,7) ,則就可以解了。

  其實這和數據集的點數和選擇的階數有關,如果點數小於階數則會出現無窮解的情況,如果點數等於階數,那麼剛好有解可以完全擬合所有數據點,如果點數大於階數,則會求的近似解。

  那麼對於點數小於階數的情況,如何求解?在python的多項式擬合函數中是可以擬合的,而且效果不錯,具體算法不是很瞭解,可以想辦法參考python的ployfit()函數的實現。

4. 擬合階數的選擇

   在前面的推導中,多項式的階數被固定了,那麼實際場景下應該如何選擇合適的階數M 呢?

  1. 一般會選擇階數M 小於點數N
  2. 把訓練數據分爲訓練集合驗證集,在訓練集上,同時用不同的M 值訓練多個模型,然後選擇在驗證集誤差最小的階數M

5. 後續

  如果後續還想寫的話,可以考慮正則化問題。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章