最簡單的機器學習入門:線性迴歸

前言

線性函數用來做迴歸、做分類其實是數學內容應用與時間的一個簡單方法,其實這個高中生都可能會了解,只不過針對批樣本用到了矩陣,會涉及到一些線性代數內容。讓我們來了解一下這個數學背後的邏輯。

簡單的y=wx+b直線函數表達式

我們知道這是一關於x的直線函數,給出x的值就可以知道y是多少,思考這麼一個問題:

  • 正序邏輯:告訴你w、b的值,你就知道線性公式了,我們可以給出很多x的值,每個x值都有一個對應的y值
  • 反序邏輯:給你一些點<x,y>,你能告訴我w、b的值嘛?

在生活中都是反序邏輯,颱風跟季節、溫度、溼度、地理位置等的關係,我們知道往年的這些指標,你要探討這其中的關係,從而實現對未來的預測,根據關係找數值。
問題繼續,我們的目標明確一下,就是找反序邏輯,找到w、b的值。還是拿簡單的直線函數表達式來說,兩個點可以表示一個直線,我們就可以確定一個直線,就可以求解w、b的值,所以我們需要兩個樣本,比如:

  • 樣本1:x =1,y=3 , 樣本2:x =2,y=8,構建方程組:
    {w+b=32w+b=8 \left\{ \begin{array}{l} w+b=3 \\ 2w+b=8 \end{array} \right.
  • 高中知識就可以求解:w=5,b= - 2

所以直線就是y=5w2y=5w-2,這個是不是很簡單,那麼請迎接下面的問題。

思考:一條直線,100個<x,y>呢?

這不扯淡嘛,2個<x,y>就可以確定一個直線,你給我那麼多個<x,y>幹嘛?問題來了,你怎麼知道<x,y>就符合直線分佈,不符合曲線呢?好吧,你這樣講,那我們根本就不知道公式形式是怎麼樣的,還這麼求出表達式,那這樣,大家各退一步,我們就規定一條直線,來求這個直線,看怎麼把這個問題轉化一下,於是有人提出:

  • 找到一個直線,這條直線離各個<x,y>的距離之和最小。

於是我們把這個問題數學化就是:
minδ=i100yiyiyi=wxi+b min \qquad \delta = \sum_{i}^{100} |y_i-y^`_i| \\ y^`_i= wx_i +b
其中,δ\delta就是誤差函數,這個公式還是比較容易懂的吧,每一個x有對應的y,同時也有一個對應的yy`,可是yy`哪裏來的呢?這開始了線性函數的算法求解。

  • 1、初始化一個w、b,就是給一個默認值,比如w=12,b=2
  • 2、輸入x值,可以根據公式得到y=12x+2y`=12x+2,輸入100個<x,y>中的x,我們就可以得到就誤差函數δ\delta
  • 3、得到δ\delta有什麼用呢?我的目的是讓δ\delta最小化,所以需要根據誤差來做反向傳播,通過誤差來更新w、b的值
  • 4、反向傳播的一個利器來了:梯度下降法,參考博客:最簡單的講解:梯度下降,這裏會有一個步長參數。
  • 5、通過多次正向計算yy`,反向更新w、b值,最終我們的誤差就減少來,那麼什麼時候結束呢?我們設置一個閾值0.01,前後兩次誤差結果的差值小於這個閾值就結束,

這裏用前後兩次誤差值來做終止條件,說明更新w、b已經不能減少誤差,或者說對誤差的減少幫助甚小,所以可以提前終止掉,上述流程中的步長係數、誤差閾值都是可以調整的,按照自己的節奏。

思考:組成x的不是單獨一個數值,而是一個向量?

上一節的思考是從多樣本單維度的角度思考,現在我們思考多樣本、多維度的問題,在機器學習領域,提取特徵是我們經常做的事情,比如我們做房屋售價預測,需要房屋面積、位置、戶型、樓層等多維度的特徵,這些纔是組成一個樣本,也就是說大多數情況下,樣本是標示爲:<x1,x2,x3,...xn,y><x_1,x_2,x_3,...x_n,y>。特徵有n個,y只有一個,這種情況跟我們上述的討論有些差別,但我們知道多變量的線性表達式有如下形式:
y=w1x1+w2x2+w3x3+...+wnxn+b y=w_1x_1+w_2x_2+w_3x_3+...+w_nx_n+b
這是一條樣本的表示,通過線性代數的思想,如果我們有n個方程式,可以表示爲:
y1y2yn=x11x12x1mx21x22x2mxn1xn2xnmw1w2wm+bbb \begin{aligned} \begin{array} {|c|} y_{1} \\ y_{2}\\ \vdots&\\ y_{n} \end{array} = \begin{array} {|cccc|} x_{11}&x_{12}&\cdots&x_{1m}\\ x_{21}&x_{22}&\cdots&x_{2m}\\ \vdots&\cdots&\ddots&\vdots&\\ x_{n1}&x_{n2}&\cdots&x_{nm}\\ \end{array} \begin{array} {|c|} w_{1} \\ w_{2}\\ \vdots \\ w_{m}& \end{array} + \begin{array} {|c|} b \\ b\\ \vdots&\\ b \end{array} \end{aligned}
這樣把單個樣本的誤差提升到了多樣本,從矩陣來求解,每一次的更新都通過矩陣運算,所以公式可以改爲:
Y=WX+B Y=WX+B
所以在很多博客裏,求解都是通過上述極簡公式用於表述問題。

誤差函數、梯度下降的公式

我們知道里線性函數問題的求解問題形式是:
Y=WX+B Y=WX+B
用梯度下降法要求導,爲裏利於求導計算,去掉絕對值符號,利用均方誤差(MSE)替代:
1mi=1m(yiyi)2\frac{1}{m}\sum^{m}_{i=1}(y_{i} - y_{i}^`)^2

修改一下的誤差函數就是:
δ=1m(YY)2\delta = \frac{1}{m}(Y - Y^`)^2

其實有時候想想一下,前面的係數1m\frac{1}{m}到底有沒有用?我感覺沒用,因爲這個係數完全就是控制了大小,特別是求導的時候,我們看下,利用梯度下降法主要是有一個梯度方向,根據梯度方向,更新wbw、b,是的wbw、b的改變會讓誤差變小。那麼wbw、b的更新就變爲:
W=WλδW=W2λmXYY \begin{aligned} W &=W-\lambda \frac{\partial \delta}{\partial W} \\ &=W-\frac{2 \lambda }{m}X|Y-Y^`| \end{aligned}

b=bλδb=b2λmYY \begin{aligned} b&=b- \lambda \frac{\partial \delta}{\partial b} \\ &=b- \frac{2 \lambda }{m}|Y-Y^`| \end{aligned}
算法流程其他流程就跟上述的一樣。這裏我們一定要轉換思路,這其實是一個二元函數,其中二元是指w,bw,b,不再是XX。轉換角度來思考問題會讓你更深的理解。

最小二乘法

除了梯度下降法,其實還有最小二乘法,均方誤差的一個方面就是爲了最小二乘法,具體算法可以參考劉建平的算法博客最小二乘法
主要思想是對損失函數求偏導,令w、b偏倒數等於0,這樣損失函數最小,並將公式推廣到多樣本維度,得到求解公式,主要涉及到的損失函數、求解函數如下:
J(θ)=12(XθY)T(XθY)Wb=θ=(XTX)1XTY J(\theta)= \frac{1}{2}(X\theta-Y)^T(X\theta-Y)\\ \begin{aligned} \begin{array} {|c|} W\\ b \end{array} = \theta= (X^{T}X)^{-1}X^{T}Y \end{aligned}
其中,J(θ)J(\theta)主要是由向量公式:aTa=iai2a^Ta=\sum_{i}a_i^2演變而來,θ\theta中包含了WbW、b,其中XXRm(n+1)\mathbb{R}^{m*(n+1)}YYRm1\mathbb{R}^{m*1}WWRn1\mathbb{R}^{n*1}bbR1\mathbb{R}^{1}

  • 爲什麼XXRm(n+1)\mathbb{R}^{m*(n+1)}
    其實是第一個元素是1,這個數表示的是b這個位置。有些博客不會加1,依舊認爲XXRm(n)\mathbb{R}^{m*(n)},這是在公式裏把1作爲了一個特徵,融入進來,與b相乘,符合公式。

優點:

  • 全局最優解,其實不一定是全局最有解,有可能無解,但它還是能可以得到最優解。
  • 求解方便,公式直接求解。

缺點:

  • 受異常值擾動影響大,這可能是最大的影響。
  • 當樣本特徵n非常的大的時候,計算XTXX^TX的逆矩陣是一個非常耗時的工作(nxn的矩陣求逆),甚至不可行。此時以梯度下降爲代表的迭代法仍然可以使用。大數據環境下建議超過10000個特徵就用迭代法吧。或者通過主成分分析降低特徵的維度後再用最小二乘法,一般都用牛頓迭代法。
  • 計算XTXX^TX的逆矩陣,有可能它的逆矩陣不存在,解決方法就是去掉冗餘特徵,讓XTXX^TX的行列式不爲0,這種情況下,梯度下降法依舊有效。
  • 如果擬合函數不是線性的,這時無法使用最小二乘法,需要通過一些技巧轉化爲線性才能使用,此時梯度下降仍然可以用。

根據線性代數的知識:

  • 當樣本量m很少,小於特徵數n的時候,這時擬合方程是欠定的,常用的優化方法都無法去擬合數據。
  • 當樣本量m等於特徵數n的時候,用方程組求解就可以了。
  • 當m大於n時,擬合方程是超定的,也就是我們常用與最小二乘法的場景了。

損失函數

線性迴歸的損失函數可不只有均方誤差,還有很多如下的損失函數,

有時候爲了防止過擬合,一般會加一個L1L1L2L2範式。這裏不做展開,具體展開會有專門的文章講解。

社招、校招內推時刻

本人在阿里巴巴工作,業餘時間做了社招、校招的公衆號,可以內推大家,免篩選直接面試,公衆號的一些文章也幫助大學、研究生的一些同學瞭解校招、瞭解名企,工作幾年的同學想換工作也可以找我走社招內推,同時大家對文章有問題,也可以公衆號找我,掃碼關注哦!

參考博客

線性迴歸原理小結
機器學習者都應該知道的五種損失函數!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章